Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
[EP ABI] Node_GetAttrByName returns ORT_NOT_FOUND with non-existing a…
…ttr name (#25565)

### Description
Updates `Node_GetAttributeByName` to return an error status with code
`ORT_NOT_FOUND` and set the `attribute` output parameter to `NULL` when
called with a non-existing attribute name.

Why? Currently, a caller has to do string comparison of the `OrtStatus`
error message to determine if the attribute does not exist or if another
error occurred. This can be somewhat cumbersome. With this change, the
caller can just check the error code.

### Motivation and Context
Make it easier to use `Node_GetAttributeByName`.
  • Loading branch information
adrianlizarraga authored and snnn committed Jul 30, 2025
commit b5c4b0cf596f8348b75f6f54ba1f63b97ac009d5
4 changes: 2 additions & 2 deletions csharp/src/Microsoft.ML.OnnxRuntime/Exceptions.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ internal enum ErrorCode
ModelLoaded = 8,
NotImplemented = 9,
InvalidGraph = 10,
ShapeInferenceNotRegistered = 11,
RequirementNotRegistered = 12,
ShapeInferenceNotRegistered = 11, // TODO: should be ORT_EP_FAIL
RequirementNotRegistered = 12, // TODO: should be ORT_MODEL_LOAD_CANCELED
}

/// <summary>
Expand Down
5 changes: 5 additions & 0 deletions include/onnxruntime/core/common/status.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ enum StatusCode {
EP_FAIL = 11,
MODEL_LOAD_CANCELED = 12,
MODEL_REQUIRES_COMPILATION = 13,
NOT_FOUND = 14,
};

constexpr const char* StatusCodeToString(StatusCode status) noexcept {
Expand Down Expand Up @@ -78,6 +79,8 @@ constexpr const char* StatusCodeToString(StatusCode status) noexcept {
return "MODEL_LOAD_CANCELED";
case StatusCode::MODEL_REQUIRES_COMPILATION:
return "MODEL_REQUIRES_COMPILATION";
case StatusCode::NOT_FOUND:
return "NOT_FOUND";
default:
return "GENERAL ERROR";
}
Expand Down Expand Up @@ -114,6 +117,8 @@ constexpr HRESULT StatusCodeToHRESULT(StatusCode status) noexcept {
return HRESULT_FROM_WIN32(ERROR_CANCELLED);
case StatusCode::MODEL_REQUIRES_COMPILATION:
return HRESULT_FROM_WIN32(ERROR_NOT_SUPPORTED);
case StatusCode::NOT_FOUND:
return HRESULT_FROM_WIN32(ERROR_NOT_FOUND);
default:
return E_FAIL;
}
Expand Down
18 changes: 16 additions & 2 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ typedef enum OrtErrorCode {
ORT_EP_FAIL,
ORT_MODEL_LOAD_CANCELED,
ORT_MODEL_REQUIRES_COMPILATION,
ORT_NOT_FOUND,
} OrtErrorCode;

typedef enum OrtOpAttrType {
Expand Down Expand Up @@ -6032,6 +6033,11 @@ struct OrtApi {
* Typical usage sets this to the result of Node_GetNumAttributes(). An error status is
* returned if `num_attributes` is less than the number of node attributes.
*
* \note ONNX Runtime automatically sets optional (unset) attributes to their default values if the default value
* is a constant expression that does not depend on other tensor/model characteristics. Conv's 'kernel_shape'
* attribute is an example of an optional attribute that does not have a constant default value. This function
* does not provide any unset optional attributes without a constant default value.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.23.
Expand All @@ -6043,14 +6049,22 @@ struct OrtApi {
*
* \param[in] node The OrtNode instance.
* \param[in] attribute_name The name of the attribute
* \param[out] attribute Output the attribute if its name matches 'attribute_name', otherwise output nullptr.
* \param[out] attribute Output parameter set to the OrtOpAttr instance if an attribute by the given name exists.
* For an unset optional attribute, `attribute` is set to NULL and a non-error status is
* returned. For an invalid attribute name, `attribute` is set to NULL and an error status with
* code ORT_NOT_FOUND is returned.
*
* \note ONNX Runtime automatically sets optional (unset) attributes to their default values if the default value
* is a constant expression that does not depend on other tensor/model characteristics. Conv's 'kernel_shape'
* attribute is an example of an optional attribute that does not have a constant default value. This function
* does not provide any unset optional attributes without a constant default value.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.23.
*/
ORT_API2_STATUS(Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name,
_Outptr_ const OrtOpAttr** attribute);
_Outptr_result_maybenull_ const OrtOpAttr** attribute);

/** \brief Get the attribute type as OrtOpAttrType from an OrtOpAttr.
*
Expand Down
10 changes: 8 additions & 2 deletions java/src/main/java/ai/onnxruntime/OrtException.java
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,17 @@ public enum OrtErrorCode {
/** The ONNX graph is invalid. */
ORT_INVALID_GRAPH(10),
/** The ORT execution provider failed. */
ORT_EP_FAIL(11);
ORT_EP_FAIL(11),
/** Model load was canceled. */
ORT_MODEL_LOAD_CANCELED(12),
/** Model requires compilation. */
ORT_MODEL_REQUIRES_COMPILATION(13),
/** Item was not found. */
ORT_NOT_FOUND(14);

private final int value;

private static final OrtErrorCode[] values = new OrtErrorCode[12];
private static final OrtErrorCode[] values = new OrtErrorCode[15];

static {
for (OrtErrorCode ot : OrtErrorCode.values()) {
Expand Down
6 changes: 6 additions & 0 deletions java/src/main/native/OrtJniUtil.c
Original file line number Diff line number Diff line change
Expand Up @@ -1051,6 +1051,12 @@ jint convertErrorCode(OrtErrorCode code) {
return 10;
case ORT_EP_FAIL:
return 11;
case ORT_MODEL_LOAD_CANCELED:
return 12;
case ORT_MODEL_REQUIRES_COMPILATION:
return 13;
case ORT_NOT_FOUND:
return 14;
default:
return -1; // Unknown error code
}
Expand Down
33 changes: 29 additions & 4 deletions onnxruntime/core/graph/ep_api_types.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,24 @@ static void ConvertNodeArgsToValueInfos(const EpGraph* ep_graph,
}
}

#if !defined(ORT_MINIMAL_BUILD)
static bool IsOptionalAttribute(const Node& node, const std::string& attr_name) {
const ONNX_NAMESPACE::OpSchema* op_schema = node.Op();
if (op_schema == nullptr) {
return false;
}

auto attr_schema_iter = op_schema->attributes().find(attr_name);
if (attr_schema_iter == op_schema->attributes().end()) {
return false; // Not an attribute for this operator type.
}

const ONNX_NAMESPACE::OpSchema::Attribute& attr_schema = attr_schema_iter->second;

return !attr_schema.required;
}
#endif // !defined(ORT_MINIMAL_BUILD)

//
// EpNode
//
Expand Down Expand Up @@ -268,13 +286,20 @@ gsl::span<const EpValueInfo* const> EpNode::GetOutputsSpan() const {
return outputs_;
}

const OrtOpAttr* EpNode::GetAttribute(const std::string& name) const {
const OrtOpAttr* EpNode::GetAttribute(const std::string& name, bool& is_unset_optional_attr) const {
auto iter = attributes_map_.find(name);
if (iter == attributes_map_.end()) {
return nullptr;
} else {
if (iter != attributes_map_.end()) {
is_unset_optional_attr = false;
return reinterpret_cast<const OrtOpAttr*>(iter->second.get());
}

#if !defined(ORT_MINIMAL_BUILD)
is_unset_optional_attr = IsOptionalAttribute(node_, name);
#else
// This is not properly set in a minimal build because it does not have access to the operator schema.
is_unset_optional_attr = false;
#endif // !defined(ORT_MINIMAL_BUILD)
return nullptr;
}

const std::string& EpNode::GetEpName() const {
Expand Down
5 changes: 3 additions & 2 deletions onnxruntime/core/graph/ep_api_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,9 @@ struct EpNode : public OrtNode {
// Helper that returns this node's outputs as a span of EpValueInfo pointers.
gsl::span<const EpValueInfo* const> GetOutputsSpan() const;

// Helper that gets the node's attributes by name.
const OrtOpAttr* GetAttribute(const std::string& name) const;
// Helper that gets the node's attributes by name. If the attribute is not set, returns NULL and sets the
// output parameter `is_unset_optional_attr` to true if this is an unset optional attribute.
const OrtOpAttr* GetAttribute(const std::string& name, bool& is_unset_optional_attr) const;

// Helper that gets the execution provider name that this node is assigned to run on.
const std::string& GetEpName() const;
Expand Down
13 changes: 8 additions & 5 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2993,7 +2993,8 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributes, _In_ const OrtNode* node,
API_IMPL_END
}

ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name, _Outptr_ const OrtOpAttr** attribute) {
ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name,
_Outptr_result_maybenull_ const OrtOpAttr** attribute) {
API_IMPL_BEGIN
if (attribute == nullptr) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'attribute' argument is NULL");
Expand All @@ -3004,14 +3005,16 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributeByName, _In_ const OrtNode* node,
return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "node is a ModelEditorNode which doesn't support Node_GetAttributeByName.");
}

*attribute = ep_node->GetAttribute(attribute_name);
bool is_unset_optional_attr = false;
*attribute = ep_node->GetAttribute(attribute_name, is_unset_optional_attr);

if (*attribute) {
if (*attribute || is_unset_optional_attr) {
return nullptr;
} else {
return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Attribute does not exist.");
std::ostringstream oss;
oss << "Node attribute does not exist: " << attribute_name;
return OrtApis::CreateStatus(OrtErrorCode::ORT_NOT_FOUND, oss.str().c_str());
}

API_IMPL_END
}

Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/session/ort_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -678,7 +678,7 @@ ORT_API_STATUS_IMPL(Node_GetNumAttributes, _In_ const OrtNode* node, _Out_ size_
ORT_API_STATUS_IMPL(Node_GetAttributes, _In_ const OrtNode* node,
_Out_writes_(num_attributes) const OrtOpAttr** attributes, _In_ size_t num_attributes);
ORT_API_STATUS_IMPL(Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name,
_Outptr_ const OrtOpAttr** attribute);
_Outptr_result_maybenull_ const OrtOpAttr** attribute);
ORT_API_STATUS_IMPL(OpAttr_GetType, _In_ const OrtOpAttr* attribute, _Out_ OrtOpAttrType* type);
ORT_API_STATUS_IMPL(OpAttr_GetName, _In_ const OrtOpAttr* attribute, _Outptr_ const char** name);
ORT_API_STATUS_IMPL(Node_GetNumSubgraphs, _In_ const OrtNode* node, _Out_ size_t* num_subgraphs);
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/python/onnxruntime_pybind_exceptions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ void RegisterExceptions(pybind11::module& m) {
pybind11::register_exception<EPFail>(m, "EPFail");
pybind11::register_exception<ModelLoadCanceled>(m, "ModelLoadCanceled");
pybind11::register_exception<ModelRequiresCompilation>(m, "ModelRequiresCompilation");
pybind11::register_exception<NotFound>(m, "NotFound");
}

void OrtPybindThrowIfError(onnxruntime::common::Status status) {
Expand Down Expand Up @@ -67,6 +68,8 @@ void OrtPybindThrowIfError(onnxruntime::common::Status status) {
throw ModelLoadCanceled(std::move(msg));
case onnxruntime::common::StatusCode::MODEL_REQUIRES_COMPILATION:
throw ModelRequiresCompilation(std::move(msg));
case onnxruntime::common::StatusCode::NOT_FOUND:
throw NotFound(std::move(msg));
default:
throw std::runtime_error(std::move(msg));
}
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/python/onnxruntime_pybind_exceptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ struct ModelLoadCanceled : std::runtime_error {
struct ModelRequiresCompilation : std::runtime_error {
explicit ModelRequiresCompilation(const std::string& what) : std::runtime_error(what) {}
};
struct NotFound : std::runtime_error {
explicit NotFound(const std::string& what) : std::runtime_error(what) {}
};

void RegisterExceptions(pybind11::module& m);

Expand Down
86 changes: 86 additions & 0 deletions onnxruntime/test/ep_graph/test_ep_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,92 @@ TEST(EpGraphTest, Check3LayerNestedSubgraphV2) {
CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph());
}

TEST(EpGraphTest, GetAttributeByName) {
// Load model with a single Conv that has no explicit attributes set.
auto test_graph = TestGraph::Load(ORT_TSTR("testdata/conv_default_attrs.onnx"));
ASSERT_NE(test_graph, nullptr) << "Failed to load test model";

//
// Pre-check
//

// Original Conv has no explicit attributes but Graph::Resolve() fills in default values for
// 'auto_pad' and 'group'. The other optional attributes (i.e. dilations, kernel_shape, pads, strides) do not
// have statically computable default values, so will not be filled in by Graph::Resolve().
const OrtGraph& ort_graph = test_graph->GetOrtGraph();
const OrtApi& ort_api = Ort::GetApi();

size_t num_nodes = 0;
ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNumNodes(&ort_graph, &num_nodes));
ASSERT_EQ(num_nodes, 1);

std::vector<const OrtNode*> nodes(num_nodes);
ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNodes(&ort_graph, nodes.data(), nodes.size()));

const OrtNode* conv_node = nodes[0];
const char* op_type = nullptr;
ASSERT_ORTSTATUS_OK(ort_api.Node_GetOperatorType(conv_node, &op_type));
ASSERT_STREQ(op_type, "Conv");

size_t num_attrs = 0;
ASSERT_ORTSTATUS_OK(ort_api.Node_GetNumAttributes(conv_node, &num_attrs));
ASSERT_EQ(num_attrs, 2);

std::vector<const OrtOpAttr*> attrs(num_attrs);
ASSERT_ORTSTATUS_OK(ort_api.Node_GetAttributes(conv_node, attrs.data(), attrs.size()));
for (const OrtOpAttr* attr : attrs) {
const char* attr_name_cstr = nullptr;
ASSERT_ORTSTATUS_OK(ort_api.OpAttr_GetName(attr, &attr_name_cstr));
std::string_view attr_name = attr_name_cstr;
ASSERT_TRUE(attr_name == "auto_pad" || attr_name == "group"); // Only 'auto_pad' and 'group' have been set
}

//
// Test 1: Get optional attribute that is not set (e.g., dilations). Should not get an error.
//
{
const OrtOpAttr* attr = nullptr;
Ort::Status status{ort_api.Node_GetAttributeByName(conv_node, "dilations", &attr)};
ASSERT_TRUE(status.IsOK());
ASSERT_EQ(attr, nullptr);
}

//
// Test 2: Get attribute that does not exist in operator schema. Should get a ORT_NOT_FOUND error.
//
{
const OrtOpAttr* attr = nullptr;
Ort::Status status{ort_api.Node_GetAttributeByName(conv_node, "_does_not_exist_", &attr)};
ASSERT_FALSE(status.IsOK());
ASSERT_EQ(status.GetErrorCode(), ORT_NOT_FOUND);
ASSERT_EQ(attr, nullptr);
}

//
// Test 3: Get attribute that is known to be set.
//
{
const OrtOpAttr* attr = nullptr;
ASSERT_ORTSTATUS_OK(ort_api.Node_GetAttributeByName(conv_node, "auto_pad", &attr));
ASSERT_NE(attr, nullptr);

OrtOpAttrType attr_type = ORT_OP_ATTR_UNDEFINED;
ASSERT_ORTSTATUS_OK(ort_api.OpAttr_GetType(attr, &attr_type));
ASSERT_EQ(attr_type, ORT_OP_ATTR_STRING);

std::string auto_pad_val;

// First call to ReadOpAttr gets the total byte size. Second call reads the data.
size_t total_attr_bytes = 0;
Ort::Status status2{ort_api.ReadOpAttr(attr, attr_type, nullptr, 0, &total_attr_bytes)};
auto_pad_val.resize(total_attr_bytes);

ASSERT_ORTSTATUS_OK(ort_api.ReadOpAttr(attr, attr_type, auto_pad_val.data(), total_attr_bytes,
&total_attr_bytes));
ASSERT_EQ(auto_pad_val, "NOTSET");
}
}

// Check correctness of an OrtGraph that has external initializers.
TEST(EpGraphTest, CheckModelExternalInitializers) {
auto test_graph = TestGraph::Load(ORT_TSTR("testdata/conv_qdq_external_ini.onnx"));
Expand Down
Binary file added onnxruntime/test/testdata/conv_default_attrs.onnx
Binary file not shown.
36 changes: 36 additions & 0 deletions onnxruntime/test/testdata/make_conv_default_attrs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import numpy as np
import onnx


def main():
inp_shape = (1, 2, 8, 8)
input_0 = onnx.helper.make_tensor_value_info("input_0", onnx.TensorProto.FLOAT, inp_shape)
output_0 = onnx.helper.make_tensor_value_info("output_0", onnx.TensorProto.FLOAT, None)

weight_data = [
[[[-1.5, 0.0], [0.2, 1.5]], [[-1.5, 0.0], [0.2, 1.5]]],
[[[-1.0, 0.0], [0.1333, 1.0]], [[-1.0, 0.0], [0.1333, 1.0]]],
]
weight = onnx.numpy_helper.from_array(np.array(weight_data, dtype=np.float32), "weight")
bias = onnx.numpy_helper.from_array(np.array([0.0, 0.0], dtype=np.float32), "bias")
conv_node = onnx.helper.make_node("Conv", ["input_0", "weight", "bias"], ["output_0"], name="Conv0")
graph = onnx.helper.make_graph(
[conv_node],
"Convf32",
[input_0],
[output_0],
initializer=[weight, bias],
)
opset_imports = [onnx.helper.make_opsetid("", 21)]
model = onnx.helper.make_model(graph, opset_imports=opset_imports)
model = onnx.shape_inference.infer_shapes(model)

onnx.checker.check_model(model, True)
onnx.save_model(model, "conv_default_attrs.onnx")


if __name__ == "__main__":
main()