Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,7 @@
return Ort::Status{nullptr};
}

static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) {

Check warning on line 655 in include/onnxruntime/core/providers/utils/ort_graph_to_proto.h

View workflow job for this annotation

GitHub Actions / build_x64_release

'ort_node': unreferenced formal parameter

Check failure on line 655 in include/onnxruntime/core/providers/utils/ort_graph_to_proto.h

View workflow job for this annotation

GitHub Actions / build_x64_release

the following warning is treated as an error

Check warning on line 655 in include/onnxruntime/core/providers/utils/ort_graph_to_proto.h

View workflow job for this annotation

GitHub Actions / build_x64_release_vitisai

'ort_node': unreferenced formal parameter

Check failure on line 655 in include/onnxruntime/core/providers/utils/ort_graph_to_proto.h

View workflow job for this annotation

GitHub Actions / build_x64_release_vitisai

the following warning is treated as an error
const OrtApi& ort_api = Ort::GetApi();

const char* attr_name = nullptr;
Expand Down Expand Up @@ -766,7 +766,7 @@
// TensorProto as an attribute value doesn't require a name.

OrtValue* ort_value = nullptr;
ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetTensorAttributeAsOrtValue(&ort_node, &ort_attr, &ort_value));
ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.OpAttr_GetTensorAttributeAsOrtValue(&ort_attr, &ort_value));

Ort::Value tensor(ort_value);

Expand Down
3 changes: 1 addition & 2 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -6079,7 +6079,6 @@ struct OrtApi {

/** \brief Get the OrtNode's 'TENSOR' attribute as an OrtValue.
*
* \param[in] node The OrtNode instance.
* \param[in] attribute The OrtOpAttr instance.
* \param[out] attr_tensor If successful, contains the 'TENSOR' attribute as a newly created OrtValue.
Must be freed with OrtApi::ReleaseValue.
Expand All @@ -6088,7 +6087,7 @@ struct OrtApi {
*
* \since Version 1.23.
*/
ORT_API2_STATUS(Node_GetTensorAttributeAsOrtValue, _In_ const OrtNode* node, _In_ const OrtOpAttr* attribute,
ORT_API2_STATUS(OpAttr_GetTensorAttributeAsOrtValue, _In_ const OrtOpAttr* attribute,
_Outptr_result_maybenull_ OrtValue** attr_tensor);

/** \brief Get the attribute type as OrtOpAttrType from an OrtOpAttr.
Expand Down
10 changes: 0 additions & 10 deletions onnxruntime/core/graph/abi_graph_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -252,16 +252,6 @@ struct OrtNode {
/// <returns>A status indicating success or an error.</returns>
virtual onnxruntime::Status GetAttributes(gsl::span<const OrtOpAttr*> attrs) const = 0;

/// <summary>
/// Gets the node's 'TENSOR' attribute as an OrtValue.
/// </summary>
/// <param name="attr">Node's 'TENSOR' attribute.</param>
/// <param name="value">Output parameter is set to a newly created OrtValue containing the 'TENSOR' attribute value,
/// only if the attribute is of type 'TENSOR'</param>
/// <returns>A status indicating success or an error.</returns>
virtual onnxruntime::Status GetTensorAttributeAsOrtValue(const OrtOpAttr* attr,
OrtValue*& value) const = 0;

/// <summary>
/// Gets the number of node subgraphs.
/// </summary>
Expand Down
26 changes: 0 additions & 26 deletions onnxruntime/core/graph/ep_api_types.cc
Original file line number Diff line number Diff line change
Expand Up @@ -249,32 +249,6 @@ Status EpNode::GetAttributes(gsl::span<const OrtOpAttr*> dst) const {
return Status::OK();
}

Status EpNode::GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, OrtValue*& result) const {
const auto* attr_proto = reinterpret_cast<const ONNX_NAMESPACE::AttributeProto*>(attribute);

if (attr_proto->type() != onnx::AttributeProto::TENSOR) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "This OrtOpAttr instance is not a 'TENSOR' attribute");
}

const auto& graph_viewer = ep_graph_->GetGraphViewer();
const auto& tensor_proto = attr_proto->t();

// Check that TensorProto is valid.
ORT_ENFORCE(utils::HasDataType(tensor_proto), "Tensor proto doesn't have data type.");
ORT_ENFORCE(ONNX_NAMESPACE::TensorProto::DataType_IsValid(tensor_proto.data_type()), "Tensor proto has invalid data type.");
ORT_ENFORCE(!utils::HasExternalData(tensor_proto),
"Tensor proto with external data for value attribute is not supported.");

// Initialize OrtValue for tensor attribute.
auto tensor_attribute_value = std::make_unique<OrtValue>();
AllocatorPtr tensor_attribute_allocator = CPUAllocator::DefaultInstance();
ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), graph_viewer.ModelPath(), tensor_proto,
tensor_attribute_allocator, *tensor_attribute_value));

result = tensor_attribute_value.release();
return Status::OK();
}

Status EpNode::GetNumSubgraphs(size_t& num_subgraphs) const {
num_subgraphs = subgraphs_.size();
return Status::OK();
Expand Down
3 changes: 0 additions & 3 deletions onnxruntime/core/graph/ep_api_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,6 @@ struct EpNode : public OrtNode {
// Gets the node's attributes.
Status GetAttributes(gsl::span<const OrtOpAttr*> attrs) const override;

Status GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute,
OrtValue*& attr_tensor) const override;

// Gets the number of subgraphs contained by this node.
Status GetNumSubgraphs(size_t& num_subgraphs) const override;

Expand Down
5 changes: 0 additions & 5 deletions onnxruntime/core/graph/model_editor_api_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,6 @@ struct ModelEditorNode : public OrtNode {
"OrtModelEditorApi does not support getting attribute OrtOpAttr for OrtNode");
}

Status GetTensorAttributeAsOrtValue(const OrtOpAttr* /*attribute*/, OrtValue*& /*attr_tensor*/) const override {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"OrtModelEditorApi does not support getting 'TENSOR' attribute for OrtNode");
}

Status GetNumSubgraphs(size_t& /*num_subgraphs*/) const override {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"OrtModelEditorApi does not support getting the subgraphs for OrtNode");
Expand Down
30 changes: 27 additions & 3 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3036,7 +3036,7 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributeByName, _In_ const OrtNode* node,
API_IMPL_END
}

ORT_API_STATUS_IMPL(OrtApis::Node_GetTensorAttributeAsOrtValue, _In_ const OrtNode* node, _In_ const OrtOpAttr* attribute, _Outptr_result_maybenull_ OrtValue** attr_tensor) {
ORT_API_STATUS_IMPL(OrtApis::OpAttr_GetTensorAttributeAsOrtValue, _In_ const OrtOpAttr* attribute, _Outptr_result_maybenull_ OrtValue** attr_tensor) {
API_IMPL_BEGIN
if (attr_tensor == nullptr) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "attr_tensor argument is null");
Expand All @@ -3045,7 +3045,31 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetTensorAttributeAsOrtValue, _In_ const OrtNo
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "attribute argument is null");
}

ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetTensorAttributeAsOrtValue(attribute, *attr_tensor));
const auto* attr_proto = reinterpret_cast<const ONNX_NAMESPACE::AttributeProto*>(attribute);

if (attr_proto->type() != onnx::AttributeProto::TENSOR) {
return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "This OrtOpAttr instance is not a 'TENSOR' attribute");
}

const auto& tensor_proto = attr_proto->t();

// Check that TensorProto is valid.
ORT_ENFORCE(utils::HasDataType(tensor_proto), "Tensor proto doesn't have data type.");
ORT_ENFORCE(ONNX_NAMESPACE::TensorProto::DataType_IsValid(tensor_proto.data_type()), "Tensor proto has invalid data type.");
ORT_ENFORCE(!utils::HasExternalData(tensor_proto),
"Tensor proto with external data for value attribute is not supported.");

// Initialize OrtValue for tensor attribute.
auto tensor_attribute_value = std::make_unique<OrtValue>();
AllocatorPtr tensor_attribute_allocator = CPUAllocator::DefaultInstance();
// The tensor in the 'Tensor' attribute's TensorProto is stored inline, not in an external file.
// Therefore, the 'model_path' passed to TensorProtoToOrtValue() may be an empty path.
std::filesystem::path model_path;
ORT_API_RETURN_IF_STATUS_NOT_OK(utils::TensorProtoToOrtValue(Env::Default(), model_path, tensor_proto,
tensor_attribute_allocator, *tensor_attribute_value));

*attr_tensor = tensor_attribute_value.release();

return nullptr;
API_IMPL_END
}
Expand Down Expand Up @@ -4134,7 +4158,7 @@ static constexpr OrtApi ort_api_1_to_23 = {
&OrtApis::Node_GetNumAttributes,
&OrtApis::Node_GetAttributes,
&OrtApis::Node_GetAttributeByName,
&OrtApis::Node_GetTensorAttributeAsOrtValue,
&OrtApis::OpAttr_GetTensorAttributeAsOrtValue,
&OrtApis::OpAttr_GetType,
&OrtApis::OpAttr_GetName,
&OrtApis::Node_GetNumSubgraphs,
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/session/ort_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,7 @@ ORT_API_STATUS_IMPL(Node_GetAttributes, _In_ const OrtNode* node,
_Out_writes_(num_attributes) const OrtOpAttr** attributes, _In_ size_t num_attributes);
ORT_API_STATUS_IMPL(Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name,
_Outptr_result_maybenull_ const OrtOpAttr** attribute);
ORT_API_STATUS_IMPL(Node_GetTensorAttributeAsOrtValue, _In_ const OrtNode* node, _In_ const OrtOpAttr* attribute,
ORT_API_STATUS_IMPL(OpAttr_GetTensorAttributeAsOrtValue, _In_ const OrtOpAttr* attribute,
_Outptr_result_maybenull_ OrtValue** attr_tensor);
ORT_API_STATUS_IMPL(OpAttr_GetType, _In_ const OrtOpAttr* attribute, _Out_ OrtOpAttrType* type);
ORT_API_STATUS_IMPL(OpAttr_GetName, _In_ const OrtOpAttr* attribute, _Outptr_ const char** name);
Expand Down
Loading