From 6da8bf6f162d0e5c5a12e89eb72cc5e7eb0fdb3a Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Sun, 27 Jul 2025 21:03:24 -0700 Subject: [PATCH 01/20] Initial commit with experimental code --- .../core/session/onnxruntime_c_api.h | 14 ++++ .../core/framework/tensorprotoutils.cc | 26 ++++++ onnxruntime/core/graph/abi_graph_types.h | 15 ++++ onnxruntime/core/graph/ep_api_types.cc | 39 +++++++++ onnxruntime/core/graph/ep_api_types.h | 4 + .../core/graph/model_editor_api_types.h | 5 ++ onnxruntime/core/session/onnxruntime_c_api.cc | 80 +++++++++++++++++++ onnxruntime/core/session/ort_apis.h | 2 + 8 files changed, 185 insertions(+) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 2f0e4aa7ce108..82a384967e689 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -275,6 +275,7 @@ typedef enum OrtOpAttrType { ORT_OP_ATTR_STRING, ORT_OP_ATTR_STRINGS, ORT_OP_ATTR_GRAPH, + ORT_OP_ATTR_TENSOR, } OrtOpAttrType; //! @} @@ -6052,6 +6053,19 @@ struct OrtApi { ORT_API2_STATUS(Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name, _Outptr_ const OrtOpAttr** attribute); + /** \brief Get the OrtNode's 'TENSOR' attribute as an OrtValue. + * + * \param[in] node The OrtNode instance. + * \param[in] attribute The OrtOpAttr instance. + * \param[out] name Output parameter set to the attribute's name. The name is a null-terminated string. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Node_GetTensorAttributeAsOrtValue, _In_ const OrtNode* node, _In_ const OrtOpAttr* attribute, + _Outptr_ const OrtValue** attr_tensor); + /** \brief Get the attribute type as OrtOpAttrType from an OrtOpAttr. * * \param[in] attribute The OrtOpAttr instance. diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index ff440b595e499..966f52aef03b1 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -1372,6 +1372,32 @@ Status TensorProtoToOrtValue(const Env& env, const std::filesystem::path& model_ return TensorProtoToOrtValueImpl(env, model_path, tensor_proto, nullptr, alloc, value); } +/* +Status TensorProtoToOrtValue(const onnx::TensorProto& tensor_proto, OrtMemoryInfo* mem_info, OrtValue& value) { + // Get shape + std::vector shape(tensor_proto.dims().begin(), tensor_proto.dims().end()); + size_t num_elements = 1; + for (auto d : shape) num_elements *= d; + + // find raw data in proto buf + void* raw_data = nullptr; + SafeInt raw_data_len = 0; + if (utils::HasRawData(tensor_proto)) { + raw_data = const_cast(tensor_proto.raw_data().data()); + raw_data_len = tensor_proto.raw_data().size(); + } + + // Wrap with CreateTensorWithDataAsOrtValue + return Ort::Value::CreateTensor( + mem_info, + data.data(), + data.size() * sizeof(float), + shape.data(), + shape.size(), + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); +} +*/ + #define CASE_TYPE(X) \ case ONNX_NAMESPACE::TensorProto_DataType_##X: \ return ONNX_TENSOR_ELEMENT_DATA_TYPE_##X; diff --git a/onnxruntime/core/graph/abi_graph_types.h b/onnxruntime/core/graph/abi_graph_types.h index 6383d29d7a2bc..6517de064a478 100644 --- a/onnxruntime/core/graph/abi_graph_types.h +++ b/onnxruntime/core/graph/abi_graph_types.h @@ -156,6 +156,13 @@ struct OrtOpAttr { ONNX_NAMESPACE::AttributeProto attr_proto; }; +/// +/// Public type that represents an ONNX tensor. Currently, an OrtTensor is interchangeable with TensorProto. +/// +struct OrtTensor { + ONNX_NAMESPACE::TensorProto tensor_proto; +}; + /// /// Public type that represents an ONNX node. /// @@ -258,6 +265,14 @@ struct OrtNode { /// A status indicating success or an error. virtual onnxruntime::Status GetNumSubgraphs(size_t& num_subgraphs) const = 0; + /// + /// Gets the node's attributes. + /// + /// Buffer into which to copy the attributes. + /// A status indicating success or an error. + virtual onnxruntime::Status GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, + const OrtValue*& attr_tensor) const = 0; + /// /// Gets the node's subgraphs (e.g., subgraphs contained by an If or Loop node). /// diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index 4ceadb6191a9b..42f17bdd661e1 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -113,12 +113,33 @@ Status EpNode::Create(const Node& node, const EpGraph* ep_graph, if (node_attrs.size() > 0) { ep_node_attributes.reserve(node_attrs.size()); + std::unordered_map> tensor_attribute_values; for (const auto& item : node_attrs) { auto attr = std::make_unique(item.second); // Copy AttributeProto and owned by this EpNode object. + + // Create and cache an OrtValue for the 'TENSOR' attribute + if (attr->type() == onnx::AttributeProto::TENSOR) { + const auto& graph_viewer = ep_graph->GetGraphViewer(); + const auto& tensor_proto = reinterpret_cast(attr.get())->t(); + + // Initialize OrtValue for tensor attribute. + // Note: using std::unique_ptr because we return a OrtValue* to the user and we want it to be stable. + auto tensor_attribute_value = std::make_unique(); + AllocatorPtr tensor_attribute_allocator = CPUAllocator::DefaultInstance(); + ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), graph_viewer.ModelPath(), tensor_proto, + tensor_attribute_allocator, *tensor_attribute_value)); + + tensor_attribute_values.emplace(tensor_proto.name(), std::move(tensor_attribute_value)); + } + ep_node_attributes.push_back(reinterpret_cast(attr.get())); ep_node_attributes_map.emplace(item.first, std::move(attr)); } + + if (!tensor_attribute_values.empty()) { + ep_node->tensor_attribute_values_ = std ::move(tensor_attribute_values); + } } std::vector ep_node_subgraphs; @@ -230,6 +251,24 @@ Status EpNode::GetAttributes(gsl::span dst) const { return Status::OK(); } +Status EpNode::GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, const OrtValue*& result) const { + const auto attr_proto = reinterpret_cast(attribute); + + if (attr_proto->type() != onnx::AttributeProto::TENSOR) { + result = nullptr; + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "This OrtOpAttr instance is not a 'TENSOR' attribute"); + } + + const auto& it = tensor_attribute_values_.find(attr_proto->name()); + if (it != tensor_attribute_values_.end()) { + result = it->second.get(); + Status::OK(); + } + + result = nullptr; + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unable to get 'TENSOR' attribute with the name ", attr_proto->name()); +} + Status EpNode::GetNumSubgraphs(size_t& num_subgraphs) const { num_subgraphs = subgraphs_.size(); return Status::OK(); diff --git a/onnxruntime/core/graph/ep_api_types.h b/onnxruntime/core/graph/ep_api_types.h index 243bdc2944ffb..92d149b782f4c 100644 --- a/onnxruntime/core/graph/ep_api_types.h +++ b/onnxruntime/core/graph/ep_api_types.h @@ -183,6 +183,9 @@ struct EpNode : public OrtNode { // Gets the node's attributes. Status GetAttributes(gsl::span attrs) const override; + Status GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, + const OrtValue*& attr_tensor) const override; + // Gets the number of subgraphs contained by this node. Status GetNumSubgraphs(size_t& num_subgraphs) const override; @@ -227,6 +230,7 @@ struct EpNode : public OrtNode { std::unordered_map> attributes_map_; std::vector attributes_; + std::unordered_map> tensor_attribute_values_; // The 'TENSOR' Attribute as an OrtValue std::vector implicit_inputs_; std::vector subgraphs_; diff --git a/onnxruntime/core/graph/model_editor_api_types.h b/onnxruntime/core/graph/model_editor_api_types.h index 5d84e48182bfe..ecefdb941cf8a 100644 --- a/onnxruntime/core/graph/model_editor_api_types.h +++ b/onnxruntime/core/graph/model_editor_api_types.h @@ -137,6 +137,11 @@ struct ModelEditorNode : public OrtNode { "OrtModelEditorApi does not support getting attribute OrtOpAttr for OrtNode"); } + Status GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, const OrtValue*& attr_tensor) const override { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "OrtModelEditorApi does not support getting 'TENSOR' attribute OrtOpAttr for OrtNode"); + } + Status GetNumSubgraphs(size_t& /*num_subgraphs*/) const override { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "OrtModelEditorApi does not support getting the subgraphs for OrtNode"); diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 27f81b18be0c9..9561acd23c16c 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -3015,6 +3015,29 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributeByName, _In_ const OrtNode* node, API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::Node_GetTensorAttributeAsOrtValue, _In_ const OrtNode* node, _In_ const OrtOpAttr* attribute, _Outptr_ const OrtValue** attr_tensor) { + API_IMPL_BEGIN + if (attr_tensor == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "attr_tensor argument is null"); + } + if (attribute == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "attribute argument is null"); + } + + const EpNode* ep_node = EpNode::ToInternal(node); + if (ep_node == nullptr) { + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "node is a ModelEditorNode which doesn't support Node_GetAttributeByName."); + } + + const auto& tensor_proto = reinterpret_cast(attribute)->t(); + + ORT_ENFORCE(utils::HasDataType(tensor_proto)); + ORT_ENFORCE(ONNX_NAMESPACE::TensorProto::DataType_IsValid(tensor_proto.data_type())); + ORT_ENFORCE(!utils::HasExternalData(tensor_proto), + "Tensor proto with external data for value attribute is not supported."); + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtApis::OpAttr_GetType, _In_ const OrtOpAttr* attribute, _Out_ OrtOpAttrType* type) { API_IMPL_BEGIN const auto attr = attribute->attr_proto; @@ -3052,6 +3075,10 @@ ORT_API_STATUS_IMPL(OrtApis::OpAttr_GetType, _In_ const OrtOpAttr* attribute, _O *type = OrtOpAttrType::ORT_OP_ATTR_GRAPH; break; } + case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_TENSOR: { + *type = OrtOpAttrType::ORT_OP_ATTR_TENSOR; + break; + } default: return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Unexpected attribute type."); } @@ -3073,6 +3100,58 @@ ORT_API_STATUS_IMPL(OrtApis::OpAttr_GetName, _In_ const OrtOpAttr* attribute, _O API_IMPL_END } +/* +ORT_API_STATUS_IMPL(OrtApis::OpAttr_GetTensorAsOrtValue, _In_ const OrtOpAttr* attribute, _Outptr_ OrtValue** attr_tensor) { + API_IMPL_BEGIN + if (attr_tensor == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "attr_tensor argument is null"); + } + if (attribute == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "attribute argument is null"); + } + + const auto& tensor_proto = reinterpret_cast(attribute)->t(); + + ORT_ENFORCE(utils::HasDataType(tensor_proto)); + ORT_ENFORCE(ONNX_NAMESPACE::TensorProto::DataType_IsValid(tensor_proto.data_type())); + ORT_ENFORCE(!utils::HasExternalData(tensor_proto), + "Tensor proto with external data for value attribute is not supported."); + + // Set up memory info for CPU + OrtMemoryInfo* mem_info; + using MemoryInfoUniquePtr = std::unique_ptr>; + ORT_API_RETURN_IF_ERROR(OrtApis::CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &mem_info)); + auto unique_ptr_mem_info = MemoryInfoUniquePtr(mem_info, OrtApis::ReleaseMemoryInfo); + + //ORT_API2_STATUS(CreateTensorWithDataAsOrtValue, _In_ const OrtMemoryInfo* info, _Inout_ void* p_data, + // size_t p_data_len, _In_ const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type, + // _Outptr_ OrtValue** out); + + const auto tensor_type = static_cast(tensor_proto.data_type()); + //const void* const raw_data = utils::HasRawData(tensor_proto) ? tensor_proto.raw_data().data() : nullptr; + //const size_t raw_data_len = utils::HasRawData(tensor_proto) ? tensor_proto.raw_data().size() : 0; + + // Get shape + std::vector shape(tensor_proto.dims().begin(), tensor_proto.dims().end()); + size_t num_elements = 1; + for (auto d : shape) num_elements *= d; + + // find raw data in proto buf + void* raw_data = nullptr; + SafeInt raw_data_len = 0; + if (utils::HasRawData(tensor_proto)) { + raw_data = const_cast(tensor_proto.raw_data().data()); + raw_data_len = tensor_proto.raw_data().size(); + } + + ORT_API_RETURN_IF_ERROR(OrtApis::CreateTensorWithDataAsOrtValue(mem_info, raw_data, raw_data_len, shape.data(), shape.size(), tensor_type, attr_tensor)); + + + return nullptr; + API_IMPL_END +} +*/ + ORT_API_STATUS_IMPL(OrtApis::Node_GetNumSubgraphs, _In_ const OrtNode* node, _Out_ size_t* num_subgraphs) { API_IMPL_BEGIN if (num_subgraphs == nullptr) { @@ -4034,6 +4113,7 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::Node_GetNumAttributes, &OrtApis::Node_GetAttributes, &OrtApis::Node_GetAttributeByName, + &OrtApis::Node_GetTensorAttributeAsOrtValue, &OrtApis::OpAttr_GetType, &OrtApis::OpAttr_GetName, &OrtApis::Node_GetNumSubgraphs, diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index d2f22397bf82c..29c24f16e6177 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -679,6 +679,8 @@ ORT_API_STATUS_IMPL(Node_GetAttributes, _In_ const OrtNode* node, _Out_writes_(num_attributes) const OrtOpAttr** attributes, _In_ size_t num_attributes); ORT_API_STATUS_IMPL(Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name, _Outptr_ const OrtOpAttr** attribute); +ORT_API_STATUS_IMPL(Node_GetTensorAttributeAsOrtValue, _In_ const OrtNode* node, _In_ const OrtOpAttr* attribute, + _Outptr_ const OrtValue** attr_tensor); ORT_API_STATUS_IMPL(OpAttr_GetType, _In_ const OrtOpAttr* attribute, _Out_ OrtOpAttrType* type); ORT_API_STATUS_IMPL(OpAttr_GetName, _In_ const OrtOpAttr* attribute, _Outptr_ const char** name); ORT_API_STATUS_IMPL(Node_GetNumSubgraphs, _In_ const OrtNode* node, _Out_ size_t* num_subgraphs); From 006e07d87c2aca6f4e01bb1916514f95d7b67493 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 28 Jul 2025 09:40:20 -0700 Subject: [PATCH 02/20] Remove the code not needed --- .../core/framework/tensorprotoutils.cc | 26 ---------- onnxruntime/core/graph/abi_graph_types.h | 7 --- onnxruntime/core/session/onnxruntime_c_api.cc | 52 ------------------- 3 files changed, 85 deletions(-) diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 966f52aef03b1..ff440b595e499 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -1372,32 +1372,6 @@ Status TensorProtoToOrtValue(const Env& env, const std::filesystem::path& model_ return TensorProtoToOrtValueImpl(env, model_path, tensor_proto, nullptr, alloc, value); } -/* -Status TensorProtoToOrtValue(const onnx::TensorProto& tensor_proto, OrtMemoryInfo* mem_info, OrtValue& value) { - // Get shape - std::vector shape(tensor_proto.dims().begin(), tensor_proto.dims().end()); - size_t num_elements = 1; - for (auto d : shape) num_elements *= d; - - // find raw data in proto buf - void* raw_data = nullptr; - SafeInt raw_data_len = 0; - if (utils::HasRawData(tensor_proto)) { - raw_data = const_cast(tensor_proto.raw_data().data()); - raw_data_len = tensor_proto.raw_data().size(); - } - - // Wrap with CreateTensorWithDataAsOrtValue - return Ort::Value::CreateTensor( - mem_info, - data.data(), - data.size() * sizeof(float), - shape.data(), - shape.size(), - ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); -} -*/ - #define CASE_TYPE(X) \ case ONNX_NAMESPACE::TensorProto_DataType_##X: \ return ONNX_TENSOR_ELEMENT_DATA_TYPE_##X; diff --git a/onnxruntime/core/graph/abi_graph_types.h b/onnxruntime/core/graph/abi_graph_types.h index 6517de064a478..855809b8856a4 100644 --- a/onnxruntime/core/graph/abi_graph_types.h +++ b/onnxruntime/core/graph/abi_graph_types.h @@ -156,13 +156,6 @@ struct OrtOpAttr { ONNX_NAMESPACE::AttributeProto attr_proto; }; -/// -/// Public type that represents an ONNX tensor. Currently, an OrtTensor is interchangeable with TensorProto. -/// -struct OrtTensor { - ONNX_NAMESPACE::TensorProto tensor_proto; -}; - /// /// Public type that represents an ONNX node. /// diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 9561acd23c16c..a976da725d044 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -3100,58 +3100,6 @@ ORT_API_STATUS_IMPL(OrtApis::OpAttr_GetName, _In_ const OrtOpAttr* attribute, _O API_IMPL_END } -/* -ORT_API_STATUS_IMPL(OrtApis::OpAttr_GetTensorAsOrtValue, _In_ const OrtOpAttr* attribute, _Outptr_ OrtValue** attr_tensor) { - API_IMPL_BEGIN - if (attr_tensor == nullptr) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "attr_tensor argument is null"); - } - if (attribute == nullptr) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "attribute argument is null"); - } - - const auto& tensor_proto = reinterpret_cast(attribute)->t(); - - ORT_ENFORCE(utils::HasDataType(tensor_proto)); - ORT_ENFORCE(ONNX_NAMESPACE::TensorProto::DataType_IsValid(tensor_proto.data_type())); - ORT_ENFORCE(!utils::HasExternalData(tensor_proto), - "Tensor proto with external data for value attribute is not supported."); - - // Set up memory info for CPU - OrtMemoryInfo* mem_info; - using MemoryInfoUniquePtr = std::unique_ptr>; - ORT_API_RETURN_IF_ERROR(OrtApis::CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &mem_info)); - auto unique_ptr_mem_info = MemoryInfoUniquePtr(mem_info, OrtApis::ReleaseMemoryInfo); - - //ORT_API2_STATUS(CreateTensorWithDataAsOrtValue, _In_ const OrtMemoryInfo* info, _Inout_ void* p_data, - // size_t p_data_len, _In_ const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type, - // _Outptr_ OrtValue** out); - - const auto tensor_type = static_cast(tensor_proto.data_type()); - //const void* const raw_data = utils::HasRawData(tensor_proto) ? tensor_proto.raw_data().data() : nullptr; - //const size_t raw_data_len = utils::HasRawData(tensor_proto) ? tensor_proto.raw_data().size() : 0; - - // Get shape - std::vector shape(tensor_proto.dims().begin(), tensor_proto.dims().end()); - size_t num_elements = 1; - for (auto d : shape) num_elements *= d; - - // find raw data in proto buf - void* raw_data = nullptr; - SafeInt raw_data_len = 0; - if (utils::HasRawData(tensor_proto)) { - raw_data = const_cast(tensor_proto.raw_data().data()); - raw_data_len = tensor_proto.raw_data().size(); - } - - ORT_API_RETURN_IF_ERROR(OrtApis::CreateTensorWithDataAsOrtValue(mem_info, raw_data, raw_data_len, shape.data(), shape.size(), tensor_type, attr_tensor)); - - - return nullptr; - API_IMPL_END -} -*/ - ORT_API_STATUS_IMPL(OrtApis::Node_GetNumSubgraphs, _In_ const OrtNode* node, _Out_ size_t* num_subgraphs) { API_IMPL_BEGIN if (num_subgraphs == nullptr) { From 7a2722e2d1a8cfb433e97f04de13f4064f458e8c Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 28 Jul 2025 13:03:06 -0700 Subject: [PATCH 03/20] update OrtGraph to proto util to support getting TENSOR attribute --- .../core/providers/utils/ort_graph_to_proto.h | 91 ++++++++++++++++++- 1 file changed, 88 insertions(+), 3 deletions(-) diff --git a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h index 0d920ab7dac89..52622bda91962 100644 --- a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h +++ b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h @@ -232,7 +232,7 @@ static Ort::Status GetOrtValueInfoTensorTypeShape(const OrtValueInfo& ort_value_ /*out*/ std::vector& dims, /*out*/ std::vector& symbolic_dims); static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, onnx::ValueInfoProto& value_info_proto); -static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto); +static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto); Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, onnx::GraphProto& graph_proto, @@ -379,7 +379,7 @@ Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, } onnx::AttributeProto* attr_proto = node_proto->add_attribute(); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_attr, *attr_proto)); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_node, *ort_attr, *attr_proto)); } } @@ -652,7 +652,7 @@ static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, return Ort::Status{nullptr}; } -static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) { +static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) { const OrtApi& ort_api = Ort::GetApi(); const char* attr_name = nullptr; @@ -758,6 +758,91 @@ static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributePr break; } + case OrtOpAttrType::ORT_OP_ATTR_TENSOR: { + attr_proto.set_type(onnx::AttributeProto_AttributeType_TENSOR); + + onnx::TensorProto tensor_proto; + std::string name = std::string(attr_name) + "_tensor_proto"; + tensor_proto.set_name(name); + tensor_proto.add_dims(2); + tensor_proto.add_dims(3); + + const OrtValue* ort_value = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetTensorAttributeAsOrtValue(&ort_node, &ort_attr, &ort_value)); + + // Get tensor type and shape info + OrtTensorTypeAndShapeInfo* type_shape_info; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorTypeAndShape(ort_value, &type_shape_info)); + + // Get tensor type + ONNXTensorElementDataType element_type; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorElementType(type_shape_info, &element_type)); + + // Set tensor type + switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_FLOAT); + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT8); + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT8); + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT16); + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT16); + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT32); + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT64); + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_BOOL); + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_DOUBLE); + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT32); + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT64); + } + default: { + std::string err_msg = "Unexpected ONNXTensorElementDataType with value " + std::to_string(static_cast(element_type)); + return Ort::Status(err_msg.c_str(), ORT_FAIL); + } + } + + // Get rank + size_t num_dims; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetDimensionsCount(type_shape_info, &num_dims)); + + // Get dimensions + std::vector dims(num_dims); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetDimensions(type_shape_info, dims.data(), num_dims)); + + // Set dimensions + for (auto& dim : dims) { + tensor_proto.add_dims(dim); + } + + const void* data = nullptr; + size_t data_bytes = 0; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorData(ort_value, &data)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorSizeInBytes(ort_value, &data_bytes)); + + // Copy the Ortvalue to TensorProto as raw data + tensor_proto.set_raw_data(data, data_bytes); + + ort_api.ReleaseTensorTypeAndShapeInfo(type_shape_info); + + } default: { std::string err_msg = "Unexpected OrtOpAttrType with value " + std::to_string(static_cast(attr_type)); return Ort::Status(err_msg.c_str(), ORT_FAIL); From fe025563db94fe129ecb29da588200c9bec76c2c Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 28 Jul 2025 13:03:28 -0700 Subject: [PATCH 04/20] small update --- onnxruntime/core/session/onnxruntime_c_api.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index a976da725d044..1976c746da71f 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -3035,6 +3035,8 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetTensorAttributeAsOrtValue, _In_ const OrtNo ORT_ENFORCE(ONNX_NAMESPACE::TensorProto::DataType_IsValid(tensor_proto.data_type())); ORT_ENFORCE(!utils::HasExternalData(tensor_proto), "Tensor proto with external data for value attribute is not supported."); + + ORT_API_RETURN_IF_STATUS_NOT_OK(ep_node->GetTensorAttributeAsOrtValue(attribute, *attr_tensor)); API_IMPL_END } From 209122cfad9ff99376ab5937005ca7e56dcc6696 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 28 Jul 2025 13:24:42 -0700 Subject: [PATCH 05/20] fix minor issues --- include/onnxruntime/core/providers/utils/ort_graph_to_proto.h | 1 - onnxruntime/core/graph/ep_api_types.cc | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h index 52622bda91962..8de30dc369b56 100644 --- a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h +++ b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h @@ -841,7 +841,6 @@ static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& or tensor_proto.set_raw_data(data, data_bytes); ort_api.ReleaseTensorTypeAndShapeInfo(type_shape_info); - } default: { std::string err_msg = "Unexpected OrtOpAttrType with value " + std::to_string(static_cast(attr_type)); diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index 42f17bdd661e1..8adb79ced9bce 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -262,7 +262,7 @@ Status EpNode::GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, const Or const auto& it = tensor_attribute_values_.find(attr_proto->name()); if (it != tensor_attribute_values_.end()) { result = it->second.get(); - Status::OK(); + return Status::OK(); } result = nullptr; From af17a249bfc393896f568e1bc3a180456e7641b0 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 28 Jul 2025 13:42:14 -0700 Subject: [PATCH 06/20] fix unused parameter warning --- onnxruntime/core/graph/model_editor_api_types.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/graph/model_editor_api_types.h b/onnxruntime/core/graph/model_editor_api_types.h index ecefdb941cf8a..676240ad0206b 100644 --- a/onnxruntime/core/graph/model_editor_api_types.h +++ b/onnxruntime/core/graph/model_editor_api_types.h @@ -137,7 +137,7 @@ struct ModelEditorNode : public OrtNode { "OrtModelEditorApi does not support getting attribute OrtOpAttr for OrtNode"); } - Status GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, const OrtValue*& attr_tensor) const override { + Status GetTensorAttributeAsOrtValue(const OrtOpAttr* /*attribute*/, const OrtValue*& /*attr_tensor*/) const override { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "OrtModelEditorApi does not support getting 'TENSOR' attribute OrtOpAttr for OrtNode"); } From 2ab28baa5f3e24362563ec912330970b818af04c Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 28 Jul 2025 14:13:34 -0700 Subject: [PATCH 07/20] minor change to change the place the function declaration --- onnxruntime/core/graph/abi_graph_types.h | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/onnxruntime/core/graph/abi_graph_types.h b/onnxruntime/core/graph/abi_graph_types.h index 855809b8856a4..b04db9781ea40 100644 --- a/onnxruntime/core/graph/abi_graph_types.h +++ b/onnxruntime/core/graph/abi_graph_types.h @@ -251,13 +251,6 @@ struct OrtNode { /// A status indicating success or an error. virtual onnxruntime::Status GetAttributes(gsl::span attrs) const = 0; - /// - /// Gets the number of node subgraphs. - /// - /// Output parameter set to the number of subgraphs. - /// A status indicating success or an error. - virtual onnxruntime::Status GetNumSubgraphs(size_t& num_subgraphs) const = 0; - /// /// Gets the node's attributes. /// @@ -266,6 +259,13 @@ struct OrtNode { virtual onnxruntime::Status GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, const OrtValue*& attr_tensor) const = 0; + /// + /// Gets the number of node subgraphs. + /// + /// Output parameter set to the number of subgraphs. + /// A status indicating success or an error. + virtual onnxruntime::Status GetNumSubgraphs(size_t& num_subgraphs) const = 0; + /// /// Gets the node's subgraphs (e.g., subgraphs contained by an If or Loop node). /// From b5132079be09103d5a2bbbd5bf53029d3f160958 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 28 Jul 2025 14:21:48 -0700 Subject: [PATCH 08/20] forogt to add 'return nullptr' --- onnxruntime/core/session/onnxruntime_c_api.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 1976c746da71f..7cb87bb18b80a 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -3037,6 +3037,7 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetTensorAttributeAsOrtValue, _In_ const OrtNo "Tensor proto with external data for value attribute is not supported."); ORT_API_RETURN_IF_STATUS_NOT_OK(ep_node->GetTensorAttributeAsOrtValue(attribute, *attr_tensor)); + return nullptr; API_IMPL_END } From 05d47be9e630420e6b9a9a673cd854ef8104f684 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 28 Jul 2025 15:42:48 -0700 Subject: [PATCH 09/20] remove wrong comment and code --- include/onnxruntime/core/providers/utils/ort_graph_to_proto.h | 2 -- include/onnxruntime/core/session/onnxruntime_c_api.h | 3 ++- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h index 8de30dc369b56..880b6edce9744 100644 --- a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h +++ b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h @@ -764,8 +764,6 @@ static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& or onnx::TensorProto tensor_proto; std::string name = std::string(attr_name) + "_tensor_proto"; tensor_proto.set_name(name); - tensor_proto.add_dims(2); - tensor_proto.add_dims(3); const OrtValue* ort_value = nullptr; ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetTensorAttributeAsOrtValue(&ort_node, &ort_attr, &ort_value)); diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 82a384967e689..916dde6be7deb 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -6057,7 +6057,8 @@ struct OrtApi { * * \param[in] node The OrtNode instance. * \param[in] attribute The OrtOpAttr instance. - * \param[out] name Output parameter set to the attribute's name. The name is a null-terminated string. + * \param[out] attr_tensor Output parameter set to the 'TENSOR' attribute value or NULL. Do not cache the OrtValue + * as it is released when the owning OrtGraph is released. * * \snippet{doc} snippets.dox OrtStatus Return Value * From 287d5b5fdc3172943704d0a94a4d0e5415624468 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 29 Jul 2025 00:14:49 -0700 Subject: [PATCH 10/20] update code and add unit test --- .../core/providers/utils/ort_graph_to_proto.h | 46 +++++---- onnxruntime/core/graph/ep_api_types.cc | 54 +++++++---- onnxruntime/core/graph/ep_api_types.h | 2 +- onnxruntime/test/ep_graph/test_ep_graph.cc | 96 +++++++++++++++++++ 4 files changed, 164 insertions(+), 34 deletions(-) diff --git a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h index 880b6edce9744..744b012239e94 100644 --- a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h +++ b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h @@ -761,9 +761,9 @@ static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& or case OrtOpAttrType::ORT_OP_ATTR_TENSOR: { attr_proto.set_type(onnx::AttributeProto_AttributeType_TENSOR); - onnx::TensorProto tensor_proto; - std::string name = std::string(attr_name) + "_tensor_proto"; - tensor_proto.set_name(name); + std::unique_ptr tensor_proto = std::make_unique(); + + // TensorProto as an attribute value doesn't require a name. const OrtValue* ort_value = nullptr; ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetTensorAttributeAsOrtValue(&ort_node, &ort_attr, &ort_value)); @@ -779,37 +779,48 @@ static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& or // Set tensor type switch (element_type) { case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_FLOAT); + tensor_proto->set_data_type(onnx::TensorProto_DataType_FLOAT); + break; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT8); + tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT8); + break; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_INT8); + tensor_proto->set_data_type(onnx::TensorProto_DataType_INT8); + break; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT16); + tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT16); + break; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_INT16); + tensor_proto->set_data_type(onnx::TensorProto_DataType_INT16); + break; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_INT32); + tensor_proto->set_data_type(onnx::TensorProto_DataType_INT32); + break; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_INT64); + tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64); + break; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_BOOL); + tensor_proto->set_data_type(onnx::TensorProto_DataType_BOOL); + break; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_DOUBLE); + tensor_proto->set_data_type(onnx::TensorProto_DataType_DOUBLE); + break; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT32); + tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT32); + break; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT64); + tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT64); + break; } default: { std::string err_msg = "Unexpected ONNXTensorElementDataType with value " + std::to_string(static_cast(element_type)); @@ -827,7 +838,7 @@ static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& or // Set dimensions for (auto& dim : dims) { - tensor_proto.add_dims(dim); + tensor_proto->add_dims(dim); } const void* data = nullptr; @@ -836,9 +847,12 @@ static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& or ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorSizeInBytes(ort_value, &data_bytes)); // Copy the Ortvalue to TensorProto as raw data - tensor_proto.set_raw_data(data, data_bytes); + tensor_proto->set_raw_data(data, data_bytes); ort_api.ReleaseTensorTypeAndShapeInfo(type_shape_info); + + *(attr_proto.mutable_t()) = *tensor_proto; // Copy TensorProto into attribute + break; } default: { std::string err_msg = "Unexpected OrtOpAttrType with value " + std::to_string(static_cast(attr_type)); diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index 8adb79ced9bce..006c8b9f1ba1d 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -110,36 +110,45 @@ Status EpNode::Create(const Node& node, const EpGraph* ep_graph, const auto& node_attrs = node.GetAttributes(); std::unordered_map> ep_node_attributes_map; std::vector ep_node_attributes; + std::unordered_map> tensor_attribute_values; + + auto tensor_proto_to_ort_value = [&](const ONNX_NAMESPACE::TensorProto& tensor_proto, + std::string& tensor_proto_name) -> Status { + const auto& graph_viewer = ep_graph->GetGraphViewer(); + + // Initialize OrtValue for tensor attribute. + // Note: using std::unique_ptr because we return a OrtValue* to the user and we want it to be stable. + auto tensor_attribute_value = std::make_unique(); + AllocatorPtr tensor_attribute_allocator = CPUAllocator::DefaultInstance(); + ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), graph_viewer.ModelPath(), tensor_proto, + tensor_attribute_allocator, *tensor_attribute_value)); + + tensor_attribute_values.emplace(tensor_proto_name, std::move(tensor_attribute_value)); + + return Status::OK(); + }; if (node_attrs.size() > 0) { ep_node_attributes.reserve(node_attrs.size()); - std::unordered_map> tensor_attribute_values; for (const auto& item : node_attrs) { auto attr = std::make_unique(item.second); // Copy AttributeProto and owned by this EpNode object. // Create and cache an OrtValue for the 'TENSOR' attribute if (attr->type() == onnx::AttributeProto::TENSOR) { - const auto& graph_viewer = ep_graph->GetGraphViewer(); - const auto& tensor_proto = reinterpret_cast(attr.get())->t(); - - // Initialize OrtValue for tensor attribute. - // Note: using std::unique_ptr because we return a OrtValue* to the user and we want it to be stable. - auto tensor_attribute_value = std::make_unique(); - AllocatorPtr tensor_attribute_allocator = CPUAllocator::DefaultInstance(); - ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), graph_viewer.ModelPath(), tensor_proto, - tensor_attribute_allocator, *tensor_attribute_value)); - - tensor_attribute_values.emplace(tensor_proto.name(), std::move(tensor_attribute_value)); + const auto& tensor_proto = attr->t(); + // Some tensor proto could have no name. + // Create a name for that case since we need the name as the key for lookup later. + std::string tensor_proto_name = tensor_proto.name(); + if (tensor_proto.name().empty()) { + tensor_proto_name = node.Name() + "_" + attr->name(); + } + ORT_RETURN_IF_ERROR(tensor_proto_to_ort_value(tensor_proto, tensor_proto_name)); } ep_node_attributes.push_back(reinterpret_cast(attr.get())); ep_node_attributes_map.emplace(item.first, std::move(attr)); } - - if (!tensor_attribute_values.empty()) { - ep_node->tensor_attribute_values_ = std ::move(tensor_attribute_values); - } } std::vector ep_node_subgraphs; @@ -168,6 +177,11 @@ Status EpNode::Create(const Node& node, const EpGraph* ep_graph, ep_node->inputs_ = std::move(ep_node_inputs); ep_node->outputs_ = std::move(ep_node_outputs); ep_node->attributes_map_ = std::move(ep_node_attributes_map); + + if (!tensor_attribute_values.empty()) { + ep_node->tensor_attribute_values_ = std ::move(tensor_attribute_values); + } + ep_node->attributes_ = std::move(ep_node_attributes); ep_node->implicit_inputs_ = std::move(ep_node_implicit_inputs); ep_node->subgraphs_ = std::move(ep_node_subgraphs); @@ -259,7 +273,13 @@ Status EpNode::GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, const Or return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "This OrtOpAttr instance is not a 'TENSOR' attribute"); } - const auto& it = tensor_attribute_values_.find(attr_proto->name()); + const auto& tensor_proto = attr_proto->t(); + std::string tensor_proto_name = tensor_proto.name(); + if (tensor_proto.name().empty()) { + tensor_proto_name = node_.Name() + "_" + attr_proto->name(); + } + + const auto& it = tensor_attribute_values_.find(tensor_proto_name); if (it != tensor_attribute_values_.end()) { result = it->second.get(); return Status::OK(); diff --git a/onnxruntime/core/graph/ep_api_types.h b/onnxruntime/core/graph/ep_api_types.h index 92d149b782f4c..a05e80e20add5 100644 --- a/onnxruntime/core/graph/ep_api_types.h +++ b/onnxruntime/core/graph/ep_api_types.h @@ -230,7 +230,7 @@ struct EpNode : public OrtNode { std::unordered_map> attributes_map_; std::vector attributes_; - std::unordered_map> tensor_attribute_values_; // The 'TENSOR' Attribute as an OrtValue + std::unordered_map> tensor_attribute_values_; // The 'TENSOR' Attribute as an OrtValue std::vector implicit_inputs_; std::vector subgraphs_; diff --git a/onnxruntime/test/ep_graph/test_ep_graph.cc b/onnxruntime/test/ep_graph/test_ep_graph.cc index 45314f8f39eea..93601f0358f8a 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph.cc +++ b/onnxruntime/test/ep_graph/test_ep_graph.cc @@ -220,6 +220,39 @@ static void RunMNISTModel(const ORTCHAR_T* model_path, std::vector& outpu output_data.assign(output_values, output_values + num_output_elems); } +static void RunConstantOfShapeModel(const ORTCHAR_T* model_path, std::vector& output_data) { + auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + Ort::SessionOptions sess_options; + Ort::Session session(*ort_env, model_path, sess_options); + + std::vector input_shape = {3}; + std::vector input_data = {2, 3, 4}; + std::vector ort_inputs; + std::vector ort_input_names; + + // Add 'x' + ort_inputs.emplace_back(Ort::Value::CreateTensor( + memory_info, input_data.data(), input_data.size(), input_shape.data(), input_shape.size())); + ort_input_names.push_back("x"); + + // Run session and get outputs + std::array output_names{"y"}; + std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), + ort_inputs.size(), output_names.data(), output_names.size()); + + // Check output type and number of elements. + Ort::Value& ort_output = ort_outputs[0]; + auto output_type_shape = ort_output.GetTensorTypeAndShapeInfo(); + size_t num_output_elems = output_type_shape.GetElementCount(); + + ASSERT_EQ(output_type_shape.GetElementType(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + ASSERT_EQ(num_output_elems, 24); + + // Return output data. + const float* output_values = ort_output.GetTensorData(); + output_data.assign(output_values, output_values + num_output_elems); +} + // Test serializing an OrtGraph (MNIST) to GraphProto. Saves initializers to external file. // Checks that the outputs of the serialized and original models are identical. TEST(EpGraphTest, SerializeToProto_Mnist) { @@ -350,6 +383,65 @@ TEST(EpGraphTest, SerializeToProto_ExternalInitializersInMemory) { } } +// Test serializing an OrtGraph (MNIST) to GraphProto. Saves initializers to external file. +// Checks that the outputs of the serialized and original models are identical. +TEST(EpGraphTest, SerializeToProto_ConstantOfShape) { + const ORTCHAR_T* original_model_path = ORT_TSTR("testdata/ort_minimal_test_models/tensor_attribute.onnx"); + const ORTCHAR_T* serialized_model_path = ORT_TSTR("constant_of_shape.onnx"); + std::filesystem::remove(serialized_model_path); + + { + auto test_graph = TestGraph::Load(original_model_path); + ASSERT_NE(test_graph, nullptr) << "Failed to load test model"; + + // Serialize OrtGraph to GraphProto. Save initializers to external file. + std::string ext_ini_file_path = "constant_of_shape_serialized.bin"; + std::filesystem::remove(ext_ini_file_path); + std::ofstream ext_ini_ofs(ext_ini_file_path, std::ios::binary); + auto handle_initializer_data = [&ext_ini_ofs, &ext_ini_file_path](const OrtValueInfo* value_info, + const void* data, size_t bytes, + bool& is_external, std::string& location, + int64_t& offset) -> Ort::Status { + // OrtValueInfo* could be used to query initializer's name, type, shape, + // node consumers, etc. + (void)value_info; + + if (bytes <= 127) { + is_external = false; // Keep small initializers stored inside the TensorProto. + return Ort::Status{nullptr}; + } + + offset = ext_ini_ofs.tellp(); + location = ext_ini_file_path; + ext_ini_ofs.write(static_cast(data), bytes); + ext_ini_ofs.flush(); + is_external = true; // True if is external initializer. + + return Ort::Status{nullptr}; + }; + + ONNX_NAMESPACE::ModelProto model_proto; + ASSERT_CXX_ORTSTATUS_OK(OrtEpUtils::OrtGraphToProto(test_graph->GetOrtGraph(), model_proto, + handle_initializer_data)); + + std::ofstream ofs(serialized_model_path, std::ios::binary); + model_proto.SerializeToOstream(&ofs); + ofs.flush(); + + ASSERT_TRUE(std::filesystem::exists(serialized_model_path)); + ASSERT_TRUE(std::filesystem::exists(ext_ini_file_path)); + } + + // Compare output of the original and serialized models. Should be identical. + std::vector output_original; + std::vector output_serialized; + + RunConstantOfShapeModel(original_model_path, output_original); + RunConstantOfShapeModel(serialized_model_path, output_serialized); + + EXPECT_EQ(output_serialized, output_original); +} + static void Run3LayerModel(const ORTCHAR_T* model_path, bool input_cond, std::vector& output_data) { auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); Ort::SessionOptions sess_options; @@ -892,6 +984,10 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ ASSERT_EQ(api_node_attr_type, OrtOpAttrType::ORT_OP_ATTR_GRAPH); break; } + case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_TENSOR: { + ASSERT_EQ(api_node_attr_type, OrtOpAttrType::ORT_OP_ATTR_TENSOR); + break; + } default: // The unsupported type should be skipped by 'continue' above. It's unexpected so we force test to fail. ASSERT_ORTSTATUS_OK(ort_api.CreateStatus(ORT_FAIL, "The attribute type is not in AttributeProto_AttributeType and this case shouldn't be hit.")); From d2c0695eaca680b7a4404e2b29d5a5d23a2ee2a6 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 29 Jul 2025 14:46:27 -0700 Subject: [PATCH 11/20] address reviewers' comments --- .../core/providers/utils/ort_graph_to_proto.h | 43 +++++++++--------- .../core/session/onnxruntime_c_api.h | 5 ++- onnxruntime/core/graph/abi_graph_types.h | 10 +++-- onnxruntime/core/graph/ep_api_types.cc | 44 +++++++++++-------- onnxruntime/core/graph/ep_api_types.h | 4 ++ .../core/graph/model_editor_api_types.h | 2 +- onnxruntime/core/session/onnxruntime_c_api.cc | 11 ++--- 7 files changed, 64 insertions(+), 55 deletions(-) diff --git a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h index 744b012239e94..a0f879528a404 100644 --- a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h +++ b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h @@ -768,58 +768,69 @@ static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& or const OrtValue* ort_value = nullptr; ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetTensorAttributeAsOrtValue(&ort_node, &ort_attr, &ort_value)); + Ort::ConstValue tensor(ort_value); + // Get tensor type and shape info - OrtTensorTypeAndShapeInfo* type_shape_info; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorTypeAndShape(ort_value, &type_shape_info)); + Ort::TensorTypeAndShapeInfo type_shape_info = tensor.GetTensorTypeAndShapeInfo(); // Get tensor type - ONNXTensorElementDataType element_type; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorElementType(type_shape_info, &element_type)); + ONNXTensorElementDataType element_type = type_shape_info.GetElementType(); - // Set tensor type + size_t element_size = 0; switch (element_type) { case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { tensor_proto->set_data_type(onnx::TensorProto_DataType_FLOAT); + element_size = sizeof(float); break; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT8); + element_size = sizeof(uint8_t); break; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { tensor_proto->set_data_type(onnx::TensorProto_DataType_INT8); + element_size = sizeof(int8_t); break; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: { tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT16); + element_size = sizeof(uint16_t); break; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: { tensor_proto->set_data_type(onnx::TensorProto_DataType_INT16); + element_size = sizeof(int16_t); break; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { tensor_proto->set_data_type(onnx::TensorProto_DataType_INT32); + element_size = sizeof(int32_t); break; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64); + element_size = sizeof(int64_t); break; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { tensor_proto->set_data_type(onnx::TensorProto_DataType_BOOL); + element_size = sizeof(bool); break; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { tensor_proto->set_data_type(onnx::TensorProto_DataType_DOUBLE); + element_size = sizeof(double); break; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: { tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT32); + element_size = sizeof(uint32_t); break; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: { tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT64); + element_size = sizeof(uint64_t); break; } default: { @@ -828,30 +839,20 @@ static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& or } } - // Get rank - size_t num_dims; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetDimensionsCount(type_shape_info, &num_dims)); - - // Get dimensions - std::vector dims(num_dims); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetDimensions(type_shape_info, dims.data(), num_dims)); + auto shape = type_shape_info.GetShape(); - // Set dimensions - for (auto& dim : dims) { + for (auto& dim : shape) { tensor_proto->add_dims(dim); } - const void* data = nullptr; - size_t data_bytes = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorData(ort_value, &data)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorSizeInBytes(ort_value, &data_bytes)); + size_t element_count = type_shape_info.GetElementCount(); + size_t data_bytes = element_count * element_size; + const void* data = tensor.GetTensorData(); // Copy the Ortvalue to TensorProto as raw data tensor_proto->set_raw_data(data, data_bytes); - ort_api.ReleaseTensorTypeAndShapeInfo(type_shape_info); - - *(attr_proto.mutable_t()) = *tensor_proto; // Copy TensorProto into attribute + *(attr_proto.mutable_t()) = std::move(*tensor_proto); // move assignment break; } default: { diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 916dde6be7deb..f2078f524a144 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -6057,8 +6057,9 @@ struct OrtApi { * * \param[in] node The OrtNode instance. * \param[in] attribute The OrtOpAttr instance. - * \param[out] attr_tensor Output parameter set to the 'TENSOR' attribute value or NULL. Do not cache the OrtValue - * as it is released when the owning OrtGraph is released. + * \param[out] attr_tensor Output parameter set to the 'TENSOR' attribute value or nullptr + * if it's not a 'TENSOR' attribute. Do not cache the OrtValue as + * it is released when the owning OrtGraph is released. * * \snippet{doc} snippets.dox OrtStatus Return Value * diff --git a/onnxruntime/core/graph/abi_graph_types.h b/onnxruntime/core/graph/abi_graph_types.h index b04db9781ea40..984e2cecaa1cc 100644 --- a/onnxruntime/core/graph/abi_graph_types.h +++ b/onnxruntime/core/graph/abi_graph_types.h @@ -252,12 +252,14 @@ struct OrtNode { virtual onnxruntime::Status GetAttributes(gsl::span attrs) const = 0; /// - /// Gets the node's attributes. + /// Gets the node's 'TENSOR' attribute as an OrtValue. /// - /// Buffer into which to copy the attributes. + /// Node's 'TENSOR' attribute. + /// Output parameter set to the 'TENSOR' attribute value or nullptr + /// if it's not a 'TENSOR' attribute. /// A status indicating success or an error. - virtual onnxruntime::Status GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, - const OrtValue*& attr_tensor) const = 0; + virtual onnxruntime::Status GetTensorAttributeAsOrtValue(const OrtOpAttr* attr, + const OrtValue*& value) const = 0; /// /// Gets the number of node subgraphs. diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index 006c8b9f1ba1d..c1ed72c56e554 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -113,7 +113,7 @@ Status EpNode::Create(const Node& node, const EpGraph* ep_graph, std::unordered_map> tensor_attribute_values; auto tensor_proto_to_ort_value = [&](const ONNX_NAMESPACE::TensorProto& tensor_proto, - std::string& tensor_proto_name) -> Status { + const std::string& tensor_proto_name) -> Status { const auto& graph_viewer = ep_graph->GetGraphViewer(); // Initialize OrtValue for tensor attribute. @@ -136,13 +136,11 @@ Status EpNode::Create(const Node& node, const EpGraph* ep_graph, // Create and cache an OrtValue for the 'TENSOR' attribute if (attr->type() == onnx::AttributeProto::TENSOR) { - const auto& tensor_proto = attr->t(); // Some tensor proto could have no name. - // Create a name for that case since we need the name as the key for lookup later. - std::string tensor_proto_name = tensor_proto.name(); - if (tensor_proto.name().empty()) { - tensor_proto_name = node.Name() + "_" + attr->name(); - } + // Create a name for that case since we need a unique name as the key for lookup later. + std::string tensor_proto_name = ep_node->GetUniqueTensorAttributeName(reinterpret_cast(attr.get())); + const auto& tensor_proto = attr->t(); + ORT_RETURN_IF_ERROR(tensor_proto_to_ort_value(tensor_proto, tensor_proto_name)); } @@ -177,11 +175,7 @@ Status EpNode::Create(const Node& node, const EpGraph* ep_graph, ep_node->inputs_ = std::move(ep_node_inputs); ep_node->outputs_ = std::move(ep_node_outputs); ep_node->attributes_map_ = std::move(ep_node_attributes_map); - - if (!tensor_attribute_values.empty()) { - ep_node->tensor_attribute_values_ = std ::move(tensor_attribute_values); - } - + ep_node->tensor_attribute_values_ = std::move(tensor_attribute_values); ep_node->attributes_ = std::move(ep_node_attributes); ep_node->implicit_inputs_ = std::move(ep_node_implicit_inputs); ep_node->subgraphs_ = std::move(ep_node_subgraphs); @@ -266,18 +260,14 @@ Status EpNode::GetAttributes(gsl::span dst) const { } Status EpNode::GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, const OrtValue*& result) const { - const auto attr_proto = reinterpret_cast(attribute); + const auto* attr_proto = reinterpret_cast(attribute); if (attr_proto->type() != onnx::AttributeProto::TENSOR) { result = nullptr; return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "This OrtOpAttr instance is not a 'TENSOR' attribute"); } - const auto& tensor_proto = attr_proto->t(); - std::string tensor_proto_name = tensor_proto.name(); - if (tensor_proto.name().empty()) { - tensor_proto_name = node_.Name() + "_" + attr_proto->name(); - } + auto tensor_proto_name = GetUniqueTensorAttributeName(attribute); const auto& it = tensor_attribute_values_.find(tensor_proto_name); if (it != tensor_attribute_values_.end()) { @@ -286,7 +276,7 @@ Status EpNode::GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, const Or } result = nullptr; - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unable to get 'TENSOR' attribute with the name ", attr_proto->name()); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unable to get 'TENSOR' attribute with the name ", tensor_proto_name); } Status EpNode::GetNumSubgraphs(size_t& num_subgraphs) const { @@ -340,6 +330,22 @@ const std::string& EpNode::GetEpName() const { return node_.GetExecutionProviderType(); } +const std::string EpNode::GetUniqueTensorAttributeName(const OrtOpAttr* attr) const { + const auto* attr_proto = reinterpret_cast(attr); + + if (attr_proto->type() != onnx::AttributeProto::TENSOR) { + return ""; + } + + const auto& tensor_proto = attr_proto->t(); + std::string tensor_proto_name = node_.Name() + "_" + attr_proto->name(); + if (!tensor_proto.name().empty()) { + tensor_proto_name += "_" + attr_proto->name(); + } + + return tensor_proto_name; +} + // // EpValueInfo // diff --git a/onnxruntime/core/graph/ep_api_types.h b/onnxruntime/core/graph/ep_api_types.h index a05e80e20add5..9654362c3a7f1 100644 --- a/onnxruntime/core/graph/ep_api_types.h +++ b/onnxruntime/core/graph/ep_api_types.h @@ -218,6 +218,10 @@ struct EpNode : public OrtNode { // Helper that gets the execution provider name that this node is assigned to run on. const std::string& GetEpName() const; + // Helper to get the unique name for the 'TENSOR' attribute. Returns empty string if + // attribute is not 'TENSOR' type. + const std::string GetUniqueTensorAttributeName(const OrtOpAttr* attr) const; + private: // Back pointer to containing graph. Useful when traversing through nested subgraphs. // Will be nullptr if the EpNode was created without an owning graph. diff --git a/onnxruntime/core/graph/model_editor_api_types.h b/onnxruntime/core/graph/model_editor_api_types.h index 676240ad0206b..7e3eabd49217a 100644 --- a/onnxruntime/core/graph/model_editor_api_types.h +++ b/onnxruntime/core/graph/model_editor_api_types.h @@ -139,7 +139,7 @@ struct ModelEditorNode : public OrtNode { Status GetTensorAttributeAsOrtValue(const OrtOpAttr* /*attribute*/, const OrtValue*& /*attr_tensor*/) const override { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "OrtModelEditorApi does not support getting 'TENSOR' attribute OrtOpAttr for OrtNode"); + "OrtModelEditorApi does not support getting 'TENSOR' attribute for OrtNode"); } Status GetNumSubgraphs(size_t& /*num_subgraphs*/) const override { diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 7cb87bb18b80a..1ee79893a9731 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -3024,19 +3024,14 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetTensorAttributeAsOrtValue, _In_ const OrtNo return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "attribute argument is null"); } - const EpNode* ep_node = EpNode::ToInternal(node); - if (ep_node == nullptr) { - return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "node is a ModelEditorNode which doesn't support Node_GetAttributeByName."); - } - const auto& tensor_proto = reinterpret_cast(attribute)->t(); - ORT_ENFORCE(utils::HasDataType(tensor_proto)); - ORT_ENFORCE(ONNX_NAMESPACE::TensorProto::DataType_IsValid(tensor_proto.data_type())); + ORT_ENFORCE(utils::HasDataType(tensor_proto), "Tensor proto doesn't have data type."); + ORT_ENFORCE(ONNX_NAMESPACE::TensorProto::DataType_IsValid(tensor_proto.data_type()), "Tensor proto has invalid data type."); ORT_ENFORCE(!utils::HasExternalData(tensor_proto), "Tensor proto with external data for value attribute is not supported."); - ORT_API_RETURN_IF_STATUS_NOT_OK(ep_node->GetTensorAttributeAsOrtValue(attribute, *attr_tensor)); + ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetTensorAttributeAsOrtValue(attribute, *attr_tensor)); return nullptr; API_IMPL_END } From 3a46c881d2a6225b8102e63e8f4c638b32d06d53 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 29 Jul 2025 15:42:31 -0700 Subject: [PATCH 12/20] make it on demand for creating and caching the 'TENSOR' attribute as an OrtValue --- onnxruntime/core/graph/abi_graph_types.h | 2 +- onnxruntime/core/graph/ep_api_types.cc | 57 +++++++++++------------- onnxruntime/core/graph/ep_api_types.h | 2 +- 3 files changed, 27 insertions(+), 34 deletions(-) diff --git a/onnxruntime/core/graph/abi_graph_types.h b/onnxruntime/core/graph/abi_graph_types.h index 984e2cecaa1cc..b14fea7e221d5 100644 --- a/onnxruntime/core/graph/abi_graph_types.h +++ b/onnxruntime/core/graph/abi_graph_types.h @@ -259,7 +259,7 @@ struct OrtNode { /// if it's not a 'TENSOR' attribute. /// A status indicating success or an error. virtual onnxruntime::Status GetTensorAttributeAsOrtValue(const OrtOpAttr* attr, - const OrtValue*& value) const = 0; + const OrtValue*& value) = 0; /// /// Gets the number of node subgraphs. diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index c1ed72c56e554..98e12b9b684ae 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -110,40 +110,12 @@ Status EpNode::Create(const Node& node, const EpGraph* ep_graph, const auto& node_attrs = node.GetAttributes(); std::unordered_map> ep_node_attributes_map; std::vector ep_node_attributes; - std::unordered_map> tensor_attribute_values; - - auto tensor_proto_to_ort_value = [&](const ONNX_NAMESPACE::TensorProto& tensor_proto, - const std::string& tensor_proto_name) -> Status { - const auto& graph_viewer = ep_graph->GetGraphViewer(); - - // Initialize OrtValue for tensor attribute. - // Note: using std::unique_ptr because we return a OrtValue* to the user and we want it to be stable. - auto tensor_attribute_value = std::make_unique(); - AllocatorPtr tensor_attribute_allocator = CPUAllocator::DefaultInstance(); - ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), graph_viewer.ModelPath(), tensor_proto, - tensor_attribute_allocator, *tensor_attribute_value)); - - tensor_attribute_values.emplace(tensor_proto_name, std::move(tensor_attribute_value)); - - return Status::OK(); - }; if (node_attrs.size() > 0) { ep_node_attributes.reserve(node_attrs.size()); for (const auto& item : node_attrs) { auto attr = std::make_unique(item.second); // Copy AttributeProto and owned by this EpNode object. - - // Create and cache an OrtValue for the 'TENSOR' attribute - if (attr->type() == onnx::AttributeProto::TENSOR) { - // Some tensor proto could have no name. - // Create a name for that case since we need a unique name as the key for lookup later. - std::string tensor_proto_name = ep_node->GetUniqueTensorAttributeName(reinterpret_cast(attr.get())); - const auto& tensor_proto = attr->t(); - - ORT_RETURN_IF_ERROR(tensor_proto_to_ort_value(tensor_proto, tensor_proto_name)); - } - ep_node_attributes.push_back(reinterpret_cast(attr.get())); ep_node_attributes_map.emplace(item.first, std::move(attr)); } @@ -175,7 +147,6 @@ Status EpNode::Create(const Node& node, const EpGraph* ep_graph, ep_node->inputs_ = std::move(ep_node_inputs); ep_node->outputs_ = std::move(ep_node_outputs); ep_node->attributes_map_ = std::move(ep_node_attributes_map); - ep_node->tensor_attribute_values_ = std::move(tensor_attribute_values); ep_node->attributes_ = std::move(ep_node_attributes); ep_node->implicit_inputs_ = std::move(ep_node_implicit_inputs); ep_node->subgraphs_ = std::move(ep_node_subgraphs); @@ -259,7 +230,7 @@ Status EpNode::GetAttributes(gsl::span dst) const { return Status::OK(); } -Status EpNode::GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, const OrtValue*& result) const { +Status EpNode::GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, const OrtValue*& result) { const auto* attr_proto = reinterpret_cast(attribute); if (attr_proto->type() != onnx::AttributeProto::TENSOR) { @@ -268,6 +239,7 @@ Status EpNode::GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, const Or } auto tensor_proto_name = GetUniqueTensorAttributeName(attribute); + assert(!tensor_proto_name.empty()); const auto& it = tensor_attribute_values_.find(tensor_proto_name); if (it != tensor_attribute_values_.end()) { @@ -275,8 +247,29 @@ Status EpNode::GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, const Or return Status::OK(); } - result = nullptr; - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unable to get 'TENSOR' attribute with the name ", tensor_proto_name); + auto tensor_proto_to_ort_value = [&](const ONNX_NAMESPACE::TensorProto& tensor_proto, + const OrtValue*& result) -> Status { + const auto& graph_viewer = ep_graph_->GetGraphViewer(); + + // Initialize OrtValue for tensor attribute. + // Note: using std::unique_ptr because we return a OrtValue* to the user and we want it to be stable. + auto tensor_attribute_value = std::make_unique(); + AllocatorPtr tensor_attribute_allocator = CPUAllocator::DefaultInstance(); + ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), graph_viewer.ModelPath(), tensor_proto, + tensor_attribute_allocator, *tensor_attribute_value)); + + result = tensor_attribute_value.get(); + tensor_attribute_values_.emplace(tensor_proto_name, std::move(tensor_attribute_value)); + + return Status::OK(); + }; + + const auto* attr_proto = reinterpret_cast(attribute); + const auto& tensor_proto = attr_proto->t(); + + // Create and cache an OrtValue for the 'TENSOR' attribute + ORT_RETURN_IF_ERROR(tensor_proto_to_ort_value(tensor_proto, result)); + return Status::OK(); } Status EpNode::GetNumSubgraphs(size_t& num_subgraphs) const { diff --git a/onnxruntime/core/graph/ep_api_types.h b/onnxruntime/core/graph/ep_api_types.h index 9654362c3a7f1..74eb7b0c01416 100644 --- a/onnxruntime/core/graph/ep_api_types.h +++ b/onnxruntime/core/graph/ep_api_types.h @@ -184,7 +184,7 @@ struct EpNode : public OrtNode { Status GetAttributes(gsl::span attrs) const override; Status GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, - const OrtValue*& attr_tensor) const override; + const OrtValue*& attr_tensor) override; // Gets the number of subgraphs contained by this node. Status GetNumSubgraphs(size_t& num_subgraphs) const override; From f78d672b7ea42bacf4ae6678a4198cf5bc260c92 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 29 Jul 2025 15:53:58 -0700 Subject: [PATCH 13/20] address lintrunner issue --- onnxruntime/core/graph/ep_api_types.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index 267bc1600fe01..bc7c73c938f38 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -281,7 +281,7 @@ Status EpNode::GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, const Or return Status::OK(); }; - + const auto* attr_proto = reinterpret_cast(attribute); const auto& tensor_proto = attr_proto->t(); From 103168b674bc615c0fff98f6439bf09772c4bf67 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 29 Jul 2025 15:57:26 -0700 Subject: [PATCH 14/20] Use onnx::TensorProto instead of unique_ptr --- .../core/providers/utils/ort_graph_to_proto.h | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h index a0f879528a404..3de4e667a6880 100644 --- a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h +++ b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h @@ -761,7 +761,7 @@ static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& or case OrtOpAttrType::ORT_OP_ATTR_TENSOR: { attr_proto.set_type(onnx::AttributeProto_AttributeType_TENSOR); - std::unique_ptr tensor_proto = std::make_unique(); + onnx::TensorProto tensor_proto; // TensorProto as an attribute value doesn't require a name. @@ -779,57 +779,57 @@ static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& or size_t element_size = 0; switch (element_type) { case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { - tensor_proto->set_data_type(onnx::TensorProto_DataType_FLOAT); + tensor_proto.set_data_type(onnx::TensorProto_DataType_FLOAT); element_size = sizeof(float); break; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { - tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT8); + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT8); element_size = sizeof(uint8_t); break; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { - tensor_proto->set_data_type(onnx::TensorProto_DataType_INT8); + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT8); element_size = sizeof(int8_t); break; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: { - tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT16); + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT16); element_size = sizeof(uint16_t); break; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: { - tensor_proto->set_data_type(onnx::TensorProto_DataType_INT16); + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT16); element_size = sizeof(int16_t); break; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { - tensor_proto->set_data_type(onnx::TensorProto_DataType_INT32); + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT32); element_size = sizeof(int32_t); break; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { - tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64); + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT64); element_size = sizeof(int64_t); break; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { - tensor_proto->set_data_type(onnx::TensorProto_DataType_BOOL); + tensor_proto.set_data_type(onnx::TensorProto_DataType_BOOL); element_size = sizeof(bool); break; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { - tensor_proto->set_data_type(onnx::TensorProto_DataType_DOUBLE); + tensor_proto.set_data_type(onnx::TensorProto_DataType_DOUBLE); element_size = sizeof(double); break; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: { - tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT32); + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT32); element_size = sizeof(uint32_t); break; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: { - tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT64); + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT64); element_size = sizeof(uint64_t); break; } @@ -842,7 +842,7 @@ static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& or auto shape = type_shape_info.GetShape(); for (auto& dim : shape) { - tensor_proto->add_dims(dim); + tensor_proto.add_dims(dim); } size_t element_count = type_shape_info.GetElementCount(); @@ -850,9 +850,9 @@ static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& or const void* data = tensor.GetTensorData(); // Copy the Ortvalue to TensorProto as raw data - tensor_proto->set_raw_data(data, data_bytes); + tensor_proto.set_raw_data(data, data_bytes); - *(attr_proto.mutable_t()) = std::move(*tensor_proto); // move assignment + *(attr_proto.mutable_t()) = tensor_proto; break; } default: { From ba84595d2e59e467115868507a886dd4bc9a21d5 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 29 Jul 2025 16:10:56 -0700 Subject: [PATCH 15/20] remove redundnat code --- onnxruntime/core/graph/ep_api_types.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index bc7c73c938f38..16309b02a22fd 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -282,7 +282,6 @@ Status EpNode::GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, const Or return Status::OK(); }; - const auto* attr_proto = reinterpret_cast(attribute); const auto& tensor_proto = attr_proto->t(); // Create and cache an OrtValue for the 'TENSOR' attribute From a3b843cf5366767fc00b52df21978b92509c1020 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 29 Jul 2025 16:17:55 -0700 Subject: [PATCH 16/20] fix compile error --- onnxruntime/core/graph/model_editor_api_types.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/graph/model_editor_api_types.h b/onnxruntime/core/graph/model_editor_api_types.h index 7e3eabd49217a..a465f1a43755b 100644 --- a/onnxruntime/core/graph/model_editor_api_types.h +++ b/onnxruntime/core/graph/model_editor_api_types.h @@ -137,7 +137,7 @@ struct ModelEditorNode : public OrtNode { "OrtModelEditorApi does not support getting attribute OrtOpAttr for OrtNode"); } - Status GetTensorAttributeAsOrtValue(const OrtOpAttr* /*attribute*/, const OrtValue*& /*attr_tensor*/) const override { + Status GetTensorAttributeAsOrtValue(const OrtOpAttr* /*attribute*/, const OrtValue*& /*attr_tensor*/) override { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "OrtModelEditorApi does not support getting 'TENSOR' attribute for OrtNode"); } From d635adb371ec5fddef009d86e362c1c12c8ce3d9 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 29 Jul 2025 17:47:58 -0700 Subject: [PATCH 17/20] Make the api return OrtValue without caching it, caller now owns it --- .../core/providers/utils/ort_graph_to_proto.h | 6 +-- .../core/session/onnxruntime_c_api.h | 5 +-- onnxruntime/core/graph/abi_graph_types.h | 2 +- onnxruntime/core/graph/ep_api_types.cc | 37 +++---------------- onnxruntime/core/graph/ep_api_types.h | 6 +-- .../core/graph/model_editor_api_types.h | 2 +- onnxruntime/core/session/onnxruntime_c_api.cc | 4 +- onnxruntime/core/session/ort_apis.h | 2 +- 8 files changed, 16 insertions(+), 48 deletions(-) diff --git a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h index 3de4e667a6880..21aa797ce16eb 100644 --- a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h +++ b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h @@ -765,10 +765,10 @@ static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& or // TensorProto as an attribute value doesn't require a name. - const OrtValue* ort_value = nullptr; + OrtValue* ort_value = nullptr; ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetTensorAttributeAsOrtValue(&ort_node, &ort_attr, &ort_value)); - Ort::ConstValue tensor(ort_value); + Ort::Value tensor(ort_value); // Get tensor type and shape info Ort::TensorTypeAndShapeInfo type_shape_info = tensor.GetTensorTypeAndShapeInfo(); @@ -852,7 +852,7 @@ static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& or // Copy the Ortvalue to TensorProto as raw data tensor_proto.set_raw_data(data, data_bytes); - *(attr_proto.mutable_t()) = tensor_proto; + *(attr_proto.mutable_t()) = std::move(tensor_proto); break; } default: { diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index cc16c12e88f98..282ed8da5f295 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -6071,15 +6071,14 @@ struct OrtApi { * \param[in] node The OrtNode instance. * \param[in] attribute The OrtOpAttr instance. * \param[out] attr_tensor Output parameter set to the 'TENSOR' attribute value or nullptr - * if it's not a 'TENSOR' attribute. Do not cache the OrtValue as - * it is released when the owning OrtGraph is released. + * if it's not a 'TENSOR' attribute. Must be freed with OrtApi::ReleaseValue. * * \snippet{doc} snippets.dox OrtStatus Return Value * * \since Version 1.23. */ ORT_API2_STATUS(Node_GetTensorAttributeAsOrtValue, _In_ const OrtNode* node, _In_ const OrtOpAttr* attribute, - _Outptr_ const OrtValue** attr_tensor); + _Outptr_result_maybenull_ OrtValue** attr_tensor); /** \brief Get the attribute type as OrtOpAttrType from an OrtOpAttr. * diff --git a/onnxruntime/core/graph/abi_graph_types.h b/onnxruntime/core/graph/abi_graph_types.h index b14fea7e221d5..dd0e786b68bc2 100644 --- a/onnxruntime/core/graph/abi_graph_types.h +++ b/onnxruntime/core/graph/abi_graph_types.h @@ -259,7 +259,7 @@ struct OrtNode { /// if it's not a 'TENSOR' attribute. /// A status indicating success or an error. virtual onnxruntime::Status GetTensorAttributeAsOrtValue(const OrtOpAttr* attr, - const OrtValue*& value) = 0; + OrtValue** value) const = 0; /// /// Gets the number of node subgraphs. diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index 16309b02a22fd..b76daae2fd4ba 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -248,25 +248,16 @@ Status EpNode::GetAttributes(gsl::span dst) const { return Status::OK(); } -Status EpNode::GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, const OrtValue*& result) { +Status EpNode::GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, OrtValue** result) const { const auto* attr_proto = reinterpret_cast(attribute); if (attr_proto->type() != onnx::AttributeProto::TENSOR) { - result = nullptr; + *result = nullptr; return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "This OrtOpAttr instance is not a 'TENSOR' attribute"); } - auto tensor_proto_name = GetUniqueTensorAttributeName(attribute); - assert(!tensor_proto_name.empty()); - - const auto& it = tensor_attribute_values_.find(tensor_proto_name); - if (it != tensor_attribute_values_.end()) { - result = it->second.get(); - return Status::OK(); - } - auto tensor_proto_to_ort_value = [&](const ONNX_NAMESPACE::TensorProto& tensor_proto, - const OrtValue*& result) -> Status { + OrtValue** result) -> Status { const auto& graph_viewer = ep_graph_->GetGraphViewer(); // Initialize OrtValue for tensor attribute. @@ -276,15 +267,13 @@ Status EpNode::GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, const Or ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), graph_viewer.ModelPath(), tensor_proto, tensor_attribute_allocator, *tensor_attribute_value)); - result = tensor_attribute_value.get(); - tensor_attribute_values_.emplace(tensor_proto_name, std::move(tensor_attribute_value)); - + *result = tensor_attribute_value.release(); return Status::OK(); }; const auto& tensor_proto = attr_proto->t(); - // Create and cache an OrtValue for the 'TENSOR' attribute + // Create and returns an OrtValue for the 'TENSOR' attribute ORT_RETURN_IF_ERROR(tensor_proto_to_ort_value(tensor_proto, result)); return Status::OK(); } @@ -347,22 +336,6 @@ const std::string& EpNode::GetEpName() const { return node_.GetExecutionProviderType(); } -const std::string EpNode::GetUniqueTensorAttributeName(const OrtOpAttr* attr) const { - const auto* attr_proto = reinterpret_cast(attr); - - if (attr_proto->type() != onnx::AttributeProto::TENSOR) { - return ""; - } - - const auto& tensor_proto = attr_proto->t(); - std::string tensor_proto_name = node_.Name() + "_" + attr_proto->name(); - if (!tensor_proto.name().empty()) { - tensor_proto_name += "_" + attr_proto->name(); - } - - return tensor_proto_name; -} - // // EpValueInfo // diff --git a/onnxruntime/core/graph/ep_api_types.h b/onnxruntime/core/graph/ep_api_types.h index 2a8ec60278768..28d5ff2008b3d 100644 --- a/onnxruntime/core/graph/ep_api_types.h +++ b/onnxruntime/core/graph/ep_api_types.h @@ -184,7 +184,7 @@ struct EpNode : public OrtNode { Status GetAttributes(gsl::span attrs) const override; Status GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, - const OrtValue*& attr_tensor) override; + OrtValue** attr_tensor) const override; // Gets the number of subgraphs contained by this node. Status GetNumSubgraphs(size_t& num_subgraphs) const override; @@ -219,10 +219,6 @@ struct EpNode : public OrtNode { // Helper that gets the execution provider name that this node is assigned to run on. const std::string& GetEpName() const; - // Helper to get the unique name for the 'TENSOR' attribute. Returns empty string if - // attribute is not 'TENSOR' type. - const std::string GetUniqueTensorAttributeName(const OrtOpAttr* attr) const; - private: // Back pointer to containing graph. Useful when traversing through nested subgraphs. // Will be nullptr if the EpNode was created without an owning graph. diff --git a/onnxruntime/core/graph/model_editor_api_types.h b/onnxruntime/core/graph/model_editor_api_types.h index a465f1a43755b..dbae6b61ef515 100644 --- a/onnxruntime/core/graph/model_editor_api_types.h +++ b/onnxruntime/core/graph/model_editor_api_types.h @@ -137,7 +137,7 @@ struct ModelEditorNode : public OrtNode { "OrtModelEditorApi does not support getting attribute OrtOpAttr for OrtNode"); } - Status GetTensorAttributeAsOrtValue(const OrtOpAttr* /*attribute*/, const OrtValue*& /*attr_tensor*/) override { + Status GetTensorAttributeAsOrtValue(const OrtOpAttr* /*attribute*/, OrtValue** /*attr_tensor*/) const override { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "OrtModelEditorApi does not support getting 'TENSOR' attribute for OrtNode"); } diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index b5f5c8528dccd..06655ae3ca850 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -3018,7 +3018,7 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributeByName, _In_ const OrtNode* node, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::Node_GetTensorAttributeAsOrtValue, _In_ const OrtNode* node, _In_ const OrtOpAttr* attribute, _Outptr_ const OrtValue** attr_tensor) { +ORT_API_STATUS_IMPL(OrtApis::Node_GetTensorAttributeAsOrtValue, _In_ const OrtNode* node, _In_ const OrtOpAttr* attribute, _Outptr_result_maybenull_ OrtValue** attr_tensor) { API_IMPL_BEGIN if (attr_tensor == nullptr) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "attr_tensor argument is null"); @@ -3034,7 +3034,7 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetTensorAttributeAsOrtValue, _In_ const OrtNo ORT_ENFORCE(!utils::HasExternalData(tensor_proto), "Tensor proto with external data for value attribute is not supported."); - ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetTensorAttributeAsOrtValue(attribute, *attr_tensor)); + ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetTensorAttributeAsOrtValue(attribute, attr_tensor)); return nullptr; API_IMPL_END } diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 91eab963c4562..3eee174ff81f4 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -680,7 +680,7 @@ ORT_API_STATUS_IMPL(Node_GetAttributes, _In_ const OrtNode* node, ORT_API_STATUS_IMPL(Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name, _Outptr_result_maybenull_ const OrtOpAttr** attribute); ORT_API_STATUS_IMPL(Node_GetTensorAttributeAsOrtValue, _In_ const OrtNode* node, _In_ const OrtOpAttr* attribute, - _Outptr_ const OrtValue** attr_tensor); + _Outptr_result_maybenull_ OrtValue** attr_tensor); ORT_API_STATUS_IMPL(OpAttr_GetType, _In_ const OrtOpAttr* attribute, _Out_ OrtOpAttrType* type); ORT_API_STATUS_IMPL(OpAttr_GetName, _In_ const OrtOpAttr* attribute, _Outptr_ const char** name); ORT_API_STATUS_IMPL(Node_GetNumSubgraphs, _In_ const OrtNode* node, _Out_ size_t* num_subgraphs); From 53b355e9b05aec490139c88f581b521302ba9190 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 29 Jul 2025 21:39:02 -0700 Subject: [PATCH 18/20] address reviewer comments --- .../core/session/onnxruntime_c_api.h | 4 +-- onnxruntime/core/graph/ep_api_types.cc | 26 ++++++------------- onnxruntime/core/graph/ep_api_types.h | 1 - onnxruntime/test/ep_graph/test_ep_graph.cc | 2 +- 4 files changed, 11 insertions(+), 22 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 282ed8da5f295..c1c189ad00d9e 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -6070,8 +6070,8 @@ struct OrtApi { * * \param[in] node The OrtNode instance. * \param[in] attribute The OrtOpAttr instance. - * \param[out] attr_tensor Output parameter set to the 'TENSOR' attribute value or nullptr - * if it's not a 'TENSOR' attribute. Must be freed with OrtApi::ReleaseValue. + * \param[out] attr_tensor Returns the newly created OrtValue if it's a 'TENSOR' attribute. + Must be freed with OrtApi::ReleaseValue. * * \snippet{doc} snippets.dox OrtStatus Return Value * diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index b76daae2fd4ba..93248c261ceff 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -252,29 +252,19 @@ Status EpNode::GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, OrtValue const auto* attr_proto = reinterpret_cast(attribute); if (attr_proto->type() != onnx::AttributeProto::TENSOR) { - *result = nullptr; return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "This OrtOpAttr instance is not a 'TENSOR' attribute"); } - auto tensor_proto_to_ort_value = [&](const ONNX_NAMESPACE::TensorProto& tensor_proto, - OrtValue** result) -> Status { - const auto& graph_viewer = ep_graph_->GetGraphViewer(); - - // Initialize OrtValue for tensor attribute. - // Note: using std::unique_ptr because we return a OrtValue* to the user and we want it to be stable. - auto tensor_attribute_value = std::make_unique(); - AllocatorPtr tensor_attribute_allocator = CPUAllocator::DefaultInstance(); - ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), graph_viewer.ModelPath(), tensor_proto, - tensor_attribute_allocator, *tensor_attribute_value)); - - *result = tensor_attribute_value.release(); - return Status::OK(); - }; - + const auto& graph_viewer = ep_graph_->GetGraphViewer(); const auto& tensor_proto = attr_proto->t(); - // Create and returns an OrtValue for the 'TENSOR' attribute - ORT_RETURN_IF_ERROR(tensor_proto_to_ort_value(tensor_proto, result)); + // Initialize OrtValue for tensor attribute. + auto tensor_attribute_value = std::make_unique(); + AllocatorPtr tensor_attribute_allocator = CPUAllocator::DefaultInstance(); + ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), graph_viewer.ModelPath(), tensor_proto, + tensor_attribute_allocator, *tensor_attribute_value)); + + *result = tensor_attribute_value.release(); return Status::OK(); } diff --git a/onnxruntime/core/graph/ep_api_types.h b/onnxruntime/core/graph/ep_api_types.h index 28d5ff2008b3d..94a43553f9a80 100644 --- a/onnxruntime/core/graph/ep_api_types.h +++ b/onnxruntime/core/graph/ep_api_types.h @@ -231,7 +231,6 @@ struct EpNode : public OrtNode { std::unordered_map> attributes_map_; std::vector attributes_; - std::unordered_map> tensor_attribute_values_; // The 'TENSOR' Attribute as an OrtValue std::vector implicit_inputs_; std::vector subgraphs_; diff --git a/onnxruntime/test/ep_graph/test_ep_graph.cc b/onnxruntime/test/ep_graph/test_ep_graph.cc index d37ee77f5792d..188edad572182 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph.cc +++ b/onnxruntime/test/ep_graph/test_ep_graph.cc @@ -490,7 +490,7 @@ TEST(EpGraphTest, SerializeToProto_ConstantOfShape) { int64_t& offset) -> Ort::Status { // OrtValueInfo* could be used to query initializer's name, type, shape, // node consumers, etc. - (void)value_info; + static_cast(value_info); if (bytes <= 127) { is_external = false; // Keep small initializers stored inside the TensorProto. From 0b6a9413e8eb6e24ba490ab9ab40f8b108b7d52d Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Wed, 30 Jul 2025 09:44:35 -0700 Subject: [PATCH 19/20] Move tensor attribute check into EpNode::GetTensorAttributeAsOrtValue --- onnxruntime/core/graph/ep_api_types.cc | 6 ++++++ onnxruntime/core/session/onnxruntime_c_api.cc | 7 ------- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index 93248c261ceff..5bde9024baca6 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -258,6 +258,12 @@ Status EpNode::GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, OrtValue const auto& graph_viewer = ep_graph_->GetGraphViewer(); const auto& tensor_proto = attr_proto->t(); + // Check that TensorProto is valid. + ORT_ENFORCE(utils::HasDataType(tensor_proto), "Tensor proto doesn't have data type."); + ORT_ENFORCE(ONNX_NAMESPACE::TensorProto::DataType_IsValid(tensor_proto.data_type()), "Tensor proto has invalid data type."); + ORT_ENFORCE(!utils::HasExternalData(tensor_proto), + "Tensor proto with external data for value attribute is not supported."); + // Initialize OrtValue for tensor attribute. auto tensor_attribute_value = std::make_unique(); AllocatorPtr tensor_attribute_allocator = CPUAllocator::DefaultInstance(); diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 06655ae3ca850..89a27b8e07cfb 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -3027,13 +3027,6 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetTensorAttributeAsOrtValue, _In_ const OrtNo return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "attribute argument is null"); } - const auto& tensor_proto = reinterpret_cast(attribute)->t(); - - ORT_ENFORCE(utils::HasDataType(tensor_proto), "Tensor proto doesn't have data type."); - ORT_ENFORCE(ONNX_NAMESPACE::TensorProto::DataType_IsValid(tensor_proto.data_type()), "Tensor proto has invalid data type."); - ORT_ENFORCE(!utils::HasExternalData(tensor_proto), - "Tensor proto with external data for value attribute is not supported."); - ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetTensorAttributeAsOrtValue(attribute, attr_tensor)); return nullptr; API_IMPL_END From d9e40c09eceec185b5d14d608a0e5bcd702e7aab Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Wed, 30 Jul 2025 11:20:03 -0700 Subject: [PATCH 20/20] address reviewer comments --- include/onnxruntime/core/session/onnxruntime_c_api.h | 2 +- onnxruntime/core/graph/abi_graph_types.h | 6 +++--- onnxruntime/core/graph/ep_api_types.cc | 4 ++-- onnxruntime/core/graph/ep_api_types.h | 2 +- onnxruntime/core/graph/model_editor_api_types.h | 2 +- onnxruntime/core/session/onnxruntime_c_api.cc | 2 +- 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index c1c189ad00d9e..2899a219bdda0 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -6070,7 +6070,7 @@ struct OrtApi { * * \param[in] node The OrtNode instance. * \param[in] attribute The OrtOpAttr instance. - * \param[out] attr_tensor Returns the newly created OrtValue if it's a 'TENSOR' attribute. + * \param[out] attr_tensor If successful, contains the 'TENSOR' attribute as a newly created OrtValue. Must be freed with OrtApi::ReleaseValue. * * \snippet{doc} snippets.dox OrtStatus Return Value diff --git a/onnxruntime/core/graph/abi_graph_types.h b/onnxruntime/core/graph/abi_graph_types.h index dd0e786b68bc2..504b102e782fd 100644 --- a/onnxruntime/core/graph/abi_graph_types.h +++ b/onnxruntime/core/graph/abi_graph_types.h @@ -255,11 +255,11 @@ struct OrtNode { /// Gets the node's 'TENSOR' attribute as an OrtValue. /// /// Node's 'TENSOR' attribute. - /// Output parameter set to the 'TENSOR' attribute value or nullptr - /// if it's not a 'TENSOR' attribute. + /// Output parameter is set to a newly created OrtValue containing the 'TENSOR' attribute value, + /// only if the attribute is of type 'TENSOR' /// A status indicating success or an error. virtual onnxruntime::Status GetTensorAttributeAsOrtValue(const OrtOpAttr* attr, - OrtValue** value) const = 0; + OrtValue*& value) const = 0; /// /// Gets the number of node subgraphs. diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index 5bde9024baca6..eb7fb6937c29e 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -248,7 +248,7 @@ Status EpNode::GetAttributes(gsl::span dst) const { return Status::OK(); } -Status EpNode::GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, OrtValue** result) const { +Status EpNode::GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, OrtValue*& result) const { const auto* attr_proto = reinterpret_cast(attribute); if (attr_proto->type() != onnx::AttributeProto::TENSOR) { @@ -270,7 +270,7 @@ Status EpNode::GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, OrtValue ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), graph_viewer.ModelPath(), tensor_proto, tensor_attribute_allocator, *tensor_attribute_value)); - *result = tensor_attribute_value.release(); + result = tensor_attribute_value.release(); return Status::OK(); } diff --git a/onnxruntime/core/graph/ep_api_types.h b/onnxruntime/core/graph/ep_api_types.h index 94a43553f9a80..be78d77360cb8 100644 --- a/onnxruntime/core/graph/ep_api_types.h +++ b/onnxruntime/core/graph/ep_api_types.h @@ -184,7 +184,7 @@ struct EpNode : public OrtNode { Status GetAttributes(gsl::span attrs) const override; Status GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, - OrtValue** attr_tensor) const override; + OrtValue*& attr_tensor) const override; // Gets the number of subgraphs contained by this node. Status GetNumSubgraphs(size_t& num_subgraphs) const override; diff --git a/onnxruntime/core/graph/model_editor_api_types.h b/onnxruntime/core/graph/model_editor_api_types.h index dbae6b61ef515..d3795d911b22f 100644 --- a/onnxruntime/core/graph/model_editor_api_types.h +++ b/onnxruntime/core/graph/model_editor_api_types.h @@ -137,7 +137,7 @@ struct ModelEditorNode : public OrtNode { "OrtModelEditorApi does not support getting attribute OrtOpAttr for OrtNode"); } - Status GetTensorAttributeAsOrtValue(const OrtOpAttr* /*attribute*/, OrtValue** /*attr_tensor*/) const override { + Status GetTensorAttributeAsOrtValue(const OrtOpAttr* /*attribute*/, OrtValue*& /*attr_tensor*/) const override { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "OrtModelEditorApi does not support getting 'TENSOR' attribute for OrtNode"); } diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 89a27b8e07cfb..4c7b4d7b29c2f 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -3027,7 +3027,7 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetTensorAttributeAsOrtValue, _In_ const OrtNo return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "attribute argument is null"); } - ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetTensorAttributeAsOrtValue(attribute, attr_tensor)); + ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetTensorAttributeAsOrtValue(attribute, *attr_tensor)); return nullptr; API_IMPL_END }