diff --git a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h index 21aa797ce16eb..28ce4439fdc7e 100644 --- a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h +++ b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h @@ -232,7 +232,7 @@ static Ort::Status GetOrtValueInfoTensorTypeShape(const OrtValueInfo& ort_value_ /*out*/ std::vector& dims, /*out*/ std::vector& symbolic_dims); static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, onnx::ValueInfoProto& value_info_proto); -static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto); +static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto); Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, onnx::GraphProto& graph_proto, @@ -379,7 +379,7 @@ Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, } onnx::AttributeProto* attr_proto = node_proto->add_attribute(); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_node, *ort_attr, *attr_proto)); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_attr, *attr_proto)); } } @@ -652,7 +652,7 @@ static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, return Ort::Status{nullptr}; } -static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) { +static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) { const OrtApi& ort_api = Ort::GetApi(); const char* attr_name = nullptr; @@ -766,7 +766,7 @@ static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& or // TensorProto as an attribute value doesn't require a name. OrtValue* ort_value = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetTensorAttributeAsOrtValue(&ort_node, &ort_attr, &ort_value)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.OpAttr_GetTensorAttributeAsOrtValue(&ort_attr, &ort_value)); Ort::Value tensor(ort_value); diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 9ae6174817b7c..f137d88e5fb8a 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -6079,7 +6079,6 @@ struct OrtApi { /** \brief Get the OrtNode's 'TENSOR' attribute as an OrtValue. * - * \param[in] node The OrtNode instance. * \param[in] attribute The OrtOpAttr instance. * \param[out] attr_tensor If successful, contains the 'TENSOR' attribute as a newly created OrtValue. Must be freed with OrtApi::ReleaseValue. @@ -6088,7 +6087,7 @@ struct OrtApi { * * \since Version 1.23. */ - ORT_API2_STATUS(Node_GetTensorAttributeAsOrtValue, _In_ const OrtNode* node, _In_ const OrtOpAttr* attribute, + ORT_API2_STATUS(OpAttr_GetTensorAttributeAsOrtValue, _In_ const OrtOpAttr* attribute, _Outptr_result_maybenull_ OrtValue** attr_tensor); /** \brief Get the attribute type as OrtOpAttrType from an OrtOpAttr. diff --git a/onnxruntime/core/graph/abi_graph_types.h b/onnxruntime/core/graph/abi_graph_types.h index b99c22edb36c8..2ef7c4a9091f3 100644 --- a/onnxruntime/core/graph/abi_graph_types.h +++ b/onnxruntime/core/graph/abi_graph_types.h @@ -252,16 +252,6 @@ struct OrtNode { /// A status indicating success or an error. virtual onnxruntime::Status GetAttributes(gsl::span attrs) const = 0; - /// - /// Gets the node's 'TENSOR' attribute as an OrtValue. - /// - /// Node's 'TENSOR' attribute. - /// Output parameter is set to a newly created OrtValue containing the 'TENSOR' attribute value, - /// only if the attribute is of type 'TENSOR' - /// A status indicating success or an error. - virtual onnxruntime::Status GetTensorAttributeAsOrtValue(const OrtOpAttr* attr, - OrtValue*& value) const = 0; - /// /// Gets the number of node subgraphs. /// diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index 759a2998ace3a..0d9b93631ee8a 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -249,32 +249,6 @@ Status EpNode::GetAttributes(gsl::span dst) const { return Status::OK(); } -Status EpNode::GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, OrtValue*& result) const { - const auto* attr_proto = reinterpret_cast(attribute); - - if (attr_proto->type() != onnx::AttributeProto::TENSOR) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "This OrtOpAttr instance is not a 'TENSOR' attribute"); - } - - const auto& graph_viewer = ep_graph_->GetGraphViewer(); - const auto& tensor_proto = attr_proto->t(); - - // Check that TensorProto is valid. - ORT_ENFORCE(utils::HasDataType(tensor_proto), "Tensor proto doesn't have data type."); - ORT_ENFORCE(ONNX_NAMESPACE::TensorProto::DataType_IsValid(tensor_proto.data_type()), "Tensor proto has invalid data type."); - ORT_ENFORCE(!utils::HasExternalData(tensor_proto), - "Tensor proto with external data for value attribute is not supported."); - - // Initialize OrtValue for tensor attribute. - auto tensor_attribute_value = std::make_unique(); - AllocatorPtr tensor_attribute_allocator = CPUAllocator::DefaultInstance(); - ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), graph_viewer.ModelPath(), tensor_proto, - tensor_attribute_allocator, *tensor_attribute_value)); - - result = tensor_attribute_value.release(); - return Status::OK(); -} - Status EpNode::GetNumSubgraphs(size_t& num_subgraphs) const { num_subgraphs = subgraphs_.size(); return Status::OK(); diff --git a/onnxruntime/core/graph/ep_api_types.h b/onnxruntime/core/graph/ep_api_types.h index 7f22e265129f7..e003f02a79a2d 100644 --- a/onnxruntime/core/graph/ep_api_types.h +++ b/onnxruntime/core/graph/ep_api_types.h @@ -183,9 +183,6 @@ struct EpNode : public OrtNode { // Gets the node's attributes. Status GetAttributes(gsl::span attrs) const override; - Status GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, - OrtValue*& attr_tensor) const override; - // Gets the number of subgraphs contained by this node. Status GetNumSubgraphs(size_t& num_subgraphs) const override; diff --git a/onnxruntime/core/graph/model_editor_api_types.h b/onnxruntime/core/graph/model_editor_api_types.h index e7ffcbc7e4c90..2c0f6d6174303 100644 --- a/onnxruntime/core/graph/model_editor_api_types.h +++ b/onnxruntime/core/graph/model_editor_api_types.h @@ -138,11 +138,6 @@ struct ModelEditorNode : public OrtNode { "OrtModelEditorApi does not support getting attribute OrtOpAttr for OrtNode"); } - Status GetTensorAttributeAsOrtValue(const OrtOpAttr* /*attribute*/, OrtValue*& /*attr_tensor*/) const override { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "OrtModelEditorApi does not support getting 'TENSOR' attribute for OrtNode"); - } - Status GetNumSubgraphs(size_t& /*num_subgraphs*/) const override { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "OrtModelEditorApi does not support getting the subgraphs for OrtNode"); diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index ad0a1ad137f06..f3e2a8ce7ba7b 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -3036,7 +3036,7 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributeByName, _In_ const OrtNode* node, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::Node_GetTensorAttributeAsOrtValue, _In_ const OrtNode* node, _In_ const OrtOpAttr* attribute, _Outptr_result_maybenull_ OrtValue** attr_tensor) { +ORT_API_STATUS_IMPL(OrtApis::OpAttr_GetTensorAttributeAsOrtValue, _In_ const OrtOpAttr* attribute, _Outptr_result_maybenull_ OrtValue** attr_tensor) { API_IMPL_BEGIN if (attr_tensor == nullptr) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "attr_tensor argument is null"); @@ -3045,7 +3045,39 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetTensorAttributeAsOrtValue, _In_ const OrtNo return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "attribute argument is null"); } - ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetTensorAttributeAsOrtValue(attribute, *attr_tensor)); + const auto* attr_proto = reinterpret_cast(attribute); + + if (attr_proto->type() != onnx::AttributeProto::TENSOR) { + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "This OrtOpAttr instance is not a 'TENSOR' attribute"); + } + + const auto& tensor_proto = attr_proto->t(); + + // Check that TensorProto is valid. + if (!utils::HasDataType(tensor_proto)) { + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Tensor proto doesn't have data type."); + } + + if (!ONNX_NAMESPACE::TensorProto::DataType_IsValid(tensor_proto.data_type())) { + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Tensor proto has invalid data type."); + } + + if (utils::HasExternalData(tensor_proto)) { + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, + "Tensor proto with external data for value attribute is not supported."); + } + + // Initialize OrtValue for tensor attribute. + auto tensor_attribute_value = std::make_unique(); + AllocatorPtr tensor_attribute_allocator = CPUAllocator::DefaultInstance(); + // The tensor in the 'Tensor' attribute's TensorProto is stored inline, not in an external file. + // Therefore, the 'model_path' passed to TensorProtoToOrtValue() may be an empty path. + std::filesystem::path model_path; + ORT_API_RETURN_IF_STATUS_NOT_OK(utils::TensorProtoToOrtValue(Env::Default(), model_path, tensor_proto, + tensor_attribute_allocator, *tensor_attribute_value)); + + *attr_tensor = tensor_attribute_value.release(); + return nullptr; API_IMPL_END } @@ -4134,7 +4166,7 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::Node_GetNumAttributes, &OrtApis::Node_GetAttributes, &OrtApis::Node_GetAttributeByName, - &OrtApis::Node_GetTensorAttributeAsOrtValue, + &OrtApis::OpAttr_GetTensorAttributeAsOrtValue, &OrtApis::OpAttr_GetType, &OrtApis::OpAttr_GetName, &OrtApis::Node_GetNumSubgraphs, diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index e62149d04a16c..6dc4cf9d195cc 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -687,7 +687,7 @@ ORT_API_STATUS_IMPL(Node_GetAttributes, _In_ const OrtNode* node, _Out_writes_(num_attributes) const OrtOpAttr** attributes, _In_ size_t num_attributes); ORT_API_STATUS_IMPL(Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name, _Outptr_result_maybenull_ const OrtOpAttr** attribute); -ORT_API_STATUS_IMPL(Node_GetTensorAttributeAsOrtValue, _In_ const OrtNode* node, _In_ const OrtOpAttr* attribute, +ORT_API_STATUS_IMPL(OpAttr_GetTensorAttributeAsOrtValue, _In_ const OrtOpAttr* attribute, _Outptr_result_maybenull_ OrtValue** attr_tensor); ORT_API_STATUS_IMPL(OpAttr_GetType, _In_ const OrtOpAttr* attribute, _Out_ OrtOpAttrType* type); ORT_API_STATUS_IMPL(OpAttr_GetName, _In_ const OrtOpAttr* attribute, _Outptr_ const char** name);