Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 100 additions & 3 deletions include/onnxruntime/core/providers/utils/ort_graph_to_proto.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@
/*out*/ std::vector<int64_t>& dims,
/*out*/ std::vector<std::string>& symbolic_dims);
static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, onnx::ValueInfoProto& value_info_proto);
static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto);
static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto);

Ort::Status OrtGraphToProto(const OrtGraph& ort_graph,
onnx::GraphProto& graph_proto,
Expand Down Expand Up @@ -379,7 +379,7 @@
}

onnx::AttributeProto* attr_proto = node_proto->add_attribute();
ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_attr, *attr_proto));
ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_node, *ort_attr, *attr_proto));
}
}

Expand Down Expand Up @@ -652,7 +652,7 @@
return Ort::Status{nullptr};
}

static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) {
static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) {
const OrtApi& ort_api = Ort::GetApi();

const char* attr_name = nullptr;
Expand Down Expand Up @@ -758,6 +758,103 @@

break;
}
case OrtOpAttrType::ORT_OP_ATTR_TENSOR: {
attr_proto.set_type(onnx::AttributeProto_AttributeType_TENSOR);

onnx::TensorProto tensor_proto;

// TensorProto as an attribute value doesn't require a name.

OrtValue* ort_value = nullptr;
ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetTensorAttributeAsOrtValue(&ort_node, &ort_attr, &ort_value));

Ort::Value tensor(ort_value);

// Get tensor type and shape info
Ort::TensorTypeAndShapeInfo type_shape_info = tensor.GetTensorTypeAndShapeInfo();

// Get tensor type
ONNXTensorElementDataType element_type = type_shape_info.GetElementType();

size_t element_size = 0;
switch (element_type) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: {
tensor_proto.set_data_type(onnx::TensorProto_DataType_FLOAT);
element_size = sizeof(float);
break;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: {
tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT8);
element_size = sizeof(uint8_t);
break;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: {
tensor_proto.set_data_type(onnx::TensorProto_DataType_INT8);
element_size = sizeof(int8_t);
break;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: {
tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT16);
element_size = sizeof(uint16_t);
break;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: {
tensor_proto.set_data_type(onnx::TensorProto_DataType_INT16);
element_size = sizeof(int16_t);
break;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: {
tensor_proto.set_data_type(onnx::TensorProto_DataType_INT32);
element_size = sizeof(int32_t);
break;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: {
tensor_proto.set_data_type(onnx::TensorProto_DataType_INT64);
element_size = sizeof(int64_t);
break;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: {
tensor_proto.set_data_type(onnx::TensorProto_DataType_BOOL);
element_size = sizeof(bool);
break;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: {
tensor_proto.set_data_type(onnx::TensorProto_DataType_DOUBLE);
element_size = sizeof(double);
break;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: {
tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT32);
element_size = sizeof(uint32_t);
break;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: {
tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT64);
element_size = sizeof(uint64_t);
break;
}
default: {
std::string err_msg = "Unexpected ONNXTensorElementDataType with value " + std::to_string(static_cast<int>(element_type));
return Ort::Status(err_msg.c_str(), ORT_FAIL);
}
}

auto shape = type_shape_info.GetShape();

for (auto& dim : shape) {
tensor_proto.add_dims(dim);
}

size_t element_count = type_shape_info.GetElementCount();
size_t data_bytes = element_count * element_size;
const void* data = tensor.GetTensorData<void>();

// Copy the Ortvalue to TensorProto as raw data
tensor_proto.set_raw_data(data, data_bytes);

*(attr_proto.mutable_t()) = std::move(tensor_proto);

Check warning on line 855 in include/onnxruntime/core/providers/utils/ort_graph_to_proto.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <utility> for move [build/include_what_you_use] [4] Raw Output: include/onnxruntime/core/providers/utils/ort_graph_to_proto.h:855: Add #include <utility> for move [build/include_what_you_use] [4]
break;
}
default: {
std::string err_msg = "Unexpected OrtOpAttrType with value " + std::to_string(static_cast<int>(attr_type));
return Ort::Status(err_msg.c_str(), ORT_FAIL);
Expand Down
15 changes: 15 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ typedef enum OrtOpAttrType {
ORT_OP_ATTR_STRING,
ORT_OP_ATTR_STRINGS,
ORT_OP_ATTR_GRAPH,
ORT_OP_ATTR_TENSOR,
} OrtOpAttrType;

//! @}
Expand Down Expand Up @@ -6065,6 +6066,20 @@ struct OrtApi {
ORT_API2_STATUS(Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name,
_Outptr_result_maybenull_ const OrtOpAttr** attribute);

/** \brief Get the OrtNode's 'TENSOR' attribute as an OrtValue.
*
* \param[in] node The OrtNode instance.
* \param[in] attribute The OrtOpAttr instance. Must be an attribute of type ORT_OP_ATTR_TENSOR.
* \param[out] attr_tensor If successful, contains the 'TENSOR' attribute as a newly created OrtValue.
*                         Must be freed with OrtApi::ReleaseValue.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.23.
*/
ORT_API2_STATUS(Node_GetTensorAttributeAsOrtValue, _In_ const OrtNode* node, _In_ const OrtOpAttr* attribute,
_Outptr_result_maybenull_ OrtValue** attr_tensor);

/** \brief Get the attribute type as OrtOpAttrType from an OrtOpAttr.
*
* \param[in] attribute The OrtOpAttr instance.
Expand Down
10 changes: 10 additions & 0 deletions onnxruntime/core/graph/abi_graph_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,16 @@ struct OrtNode {
/// <returns>A status indicating success or an error.</returns>
virtual onnxruntime::Status GetAttributes(gsl::span<const OrtOpAttr*> attrs) const = 0;

/// <summary>
/// Gets the node's 'TENSOR' attribute as an OrtValue.
/// </summary>
/// <param name="attr">Node's 'TENSOR' attribute.</param>
/// <param name="value">Output parameter is set to a newly created OrtValue containing the 'TENSOR' attribute value,
/// only if the attribute is of type 'TENSOR'</param>
/// <returns>A status indicating success or an error.</returns>
virtual onnxruntime::Status GetTensorAttributeAsOrtValue(const OrtOpAttr* attr,
OrtValue*& value) const = 0;

/// <summary>
/// Gets the number of node subgraphs.
/// </summary>
Expand Down
26 changes: 26 additions & 0 deletions onnxruntime/core/graph/ep_api_types.cc
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,32 @@ Status EpNode::GetAttributes(gsl::span<const OrtOpAttr*> dst) const {
return Status::OK();
}

// Converts a node's 'TENSOR' attribute into a newly allocated OrtValue.
//
// attribute: The attribute to convert. It is reinterpreted as an ONNX AttributeProto and must have
//            type TENSOR.
// result:    Reset to nullptr on entry. On success it is set to a new OrtValue whose ownership is
//            handed to the caller (the pointer is released from a unique_ptr).
//
// Returns a FAIL status if the attribute is not a TENSOR attribute, if its TensorProto has no data
// type or an invalid one, or if it refers to external data. Any error from
// utils::TensorProtoToOrtValue is propagated unchanged. All validation failures are reported via the
// returned Status rather than by throwing, so every error path of this function looks the same to
// the caller.
Status EpNode::GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, OrtValue*& result) const {
  // Give the caller a well-defined output on every failure path.
  result = nullptr;

  // NOTE(review): this cast assumes an OrtOpAttr can be viewed as an AttributeProto — confirm against
  // the OrtOpAttr definition.
  const auto* attr_proto = reinterpret_cast<const ONNX_NAMESPACE::AttributeProto*>(attribute);

  if (attr_proto->type() != ONNX_NAMESPACE::AttributeProto::TENSOR) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "This OrtOpAttr instance is not a 'TENSOR' attribute");
  }

  const auto& tensor_proto = attr_proto->t();

  // Check that TensorProto is valid.
  if (!utils::HasDataType(tensor_proto)) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Tensor proto doesn't have data type.");
  }

  if (!ONNX_NAMESPACE::TensorProto::DataType_IsValid(tensor_proto.data_type())) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Tensor proto has invalid data type.");
  }

  if (utils::HasExternalData(tensor_proto)) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
                           "Tensor proto with external data for value attribute is not supported.");
  }

  // Initialize OrtValue for tensor attribute. The unique_ptr keeps it from leaking if the
  // conversion below fails.
  auto tensor_attribute_value = std::make_unique<OrtValue>();
  AllocatorPtr tensor_attribute_allocator = CPUAllocator::DefaultInstance();
  const auto& graph_viewer = ep_graph_->GetGraphViewer();
  ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), graph_viewer.ModelPath(), tensor_proto,
                                                   tensor_attribute_allocator, *tensor_attribute_value));

  // Transfer ownership to the caller only after the conversion has succeeded.
  result = tensor_attribute_value.release();
  return Status::OK();
}

Status EpNode::GetNumSubgraphs(size_t& num_subgraphs) const {
num_subgraphs = subgraphs_.size();
return Status::OK();
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/core/graph/ep_api_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,9 @@ struct EpNode : public OrtNode {
// Gets the node's attributes.
Status GetAttributes(gsl::span<const OrtOpAttr*> attrs) const override;

Status GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute,
OrtValue*& attr_tensor) const override;

// Gets the number of subgraphs contained by this node.
Status GetNumSubgraphs(size_t& num_subgraphs) const override;

Expand Down
5 changes: 5 additions & 0 deletions onnxruntime/core/graph/model_editor_api_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,11 @@ struct ModelEditorNode : public OrtNode {
"OrtModelEditorApi does not support getting attribute OrtOpAttr for OrtNode");
}

// Not supported for nodes created through the model editor API: both parameters are ignored and a
// NOT_IMPLEMENTED status is always returned.
Status GetTensorAttributeAsOrtValue(const OrtOpAttr* /*attribute*/, OrtValue*& /*attr_tensor*/) const override {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"OrtModelEditorApi does not support getting 'TENSOR' attribute for OrtNode");
}

Status GetNumSubgraphs(size_t& /*num_subgraphs*/) const override {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"OrtModelEditorApi does not support getting the subgraphs for OrtNode");
Expand Down
19 changes: 19 additions & 0 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3018,6 +3018,20 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributeByName, _In_ const OrtNode* node,
API_IMPL_END
}

// C API entry point that returns a node's 'TENSOR' attribute as a newly created OrtValue.
//
// Validates all three pointer arguments and returns ORT_INVALID_ARGUMENT if any of them is null.
// `*attr_tensor` is reset to nullptr before any work is done, so the caller never observes an
// uninitialized pointer when an error status is returned. The conversion itself is delegated to
// OrtNode::GetTensorAttributeAsOrtValue, and its status is forwarded to the caller.
ORT_API_STATUS_IMPL(OrtApis::Node_GetTensorAttributeAsOrtValue, _In_ const OrtNode* node, _In_ const OrtOpAttr* attribute, _Outptr_result_maybenull_ OrtValue** attr_tensor) {
  API_IMPL_BEGIN
  if (attr_tensor == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "attr_tensor argument is null");
  }

  // Well-defined output for every failure path below.
  *attr_tensor = nullptr;

  // `node` is dereferenced below, so it has to be validated like the other pointer arguments.
  if (node == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "node argument is null");
  }

  if (attribute == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "attribute argument is null");
  }

  ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetTensorAttributeAsOrtValue(attribute, *attr_tensor));
  return nullptr;
  API_IMPL_END
}

ORT_API_STATUS_IMPL(OrtApis::OpAttr_GetType, _In_ const OrtOpAttr* attribute, _Out_ OrtOpAttrType* type) {
API_IMPL_BEGIN
const auto attr = attribute->attr_proto;
Expand Down Expand Up @@ -3055,6 +3069,10 @@ ORT_API_STATUS_IMPL(OrtApis::OpAttr_GetType, _In_ const OrtOpAttr* attribute, _O
*type = OrtOpAttrType::ORT_OP_ATTR_GRAPH;
break;
}
case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_TENSOR: {
*type = OrtOpAttrType::ORT_OP_ATTR_TENSOR;
break;
}
default:
return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Unexpected attribute type.");
}
Expand Down Expand Up @@ -4037,6 +4055,7 @@ static constexpr OrtApi ort_api_1_to_23 = {
&OrtApis::Node_GetNumAttributes,
&OrtApis::Node_GetAttributes,
&OrtApis::Node_GetAttributeByName,
&OrtApis::Node_GetTensorAttributeAsOrtValue,
&OrtApis::OpAttr_GetType,
&OrtApis::OpAttr_GetName,
&OrtApis::Node_GetNumSubgraphs,
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/session/ort_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,8 @@ ORT_API_STATUS_IMPL(Node_GetAttributes, _In_ const OrtNode* node,
_Out_writes_(num_attributes) const OrtOpAttr** attributes, _In_ size_t num_attributes);
ORT_API_STATUS_IMPL(Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name,
_Outptr_result_maybenull_ const OrtOpAttr** attribute);
ORT_API_STATUS_IMPL(Node_GetTensorAttributeAsOrtValue, _In_ const OrtNode* node, _In_ const OrtOpAttr* attribute,
_Outptr_result_maybenull_ OrtValue** attr_tensor);
ORT_API_STATUS_IMPL(OpAttr_GetType, _In_ const OrtOpAttr* attribute, _Out_ OrtOpAttrType* type);
ORT_API_STATUS_IMPL(OpAttr_GetName, _In_ const OrtOpAttr* attribute, _Outptr_ const char** name);
ORT_API_STATUS_IMPL(Node_GetNumSubgraphs, _In_ const OrtNode* node, _Out_ size_t* num_subgraphs);
Expand Down
96 changes: 96 additions & 0 deletions onnxruntime/test/ep_graph/test_ep_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,39 @@ static void RunMNISTModel(const ORTCHAR_T* model_path, std::vector<float>& outpu
output_data.assign(output_values, output_values + num_output_elems);
}

// Runs the model at `model_path` on the CPU with a fixed int64 input 'x' = {2, 3, 4} (shape {3}) and
// copies the float output 'y' into `output_data`.
//
// Asserts that the output element type is float and that it holds 24 elements.
// NOTE(review): 24 presumably equals 2 * 3 * 4 because the model uses 'x' as the output shape —
// confirm against the test model.
static void RunConstantOfShapeModel(const ORTCHAR_T* model_path, std::vector<float>& output_data) {
  auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
  Ort::SessionOptions sess_options;
  Ort::Session session(*ort_env, model_path, sess_options);

  std::vector<int64_t> input_shape = {3};
  std::vector<int64_t> input_data = {2, 3, 4};
  std::vector<Ort::Value> ort_inputs;
  std::vector<const char*> ort_input_names;

  // Add 'x'. The tensor wraps input_data's buffer, so input_data must stay alive until Run() returns.
  ort_inputs.emplace_back(Ort::Value::CreateTensor<int64_t>(
      memory_info, input_data.data(), input_data.size(), input_shape.data(), input_shape.size()));
  ort_input_names.push_back("x");

  // Run session and get outputs
  std::array<const char*, 1> output_names{"y"};
  std::vector<Ort::Value> ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(),
                                                    ort_inputs.data(), ort_inputs.size(),
                                                    output_names.data(), output_names.size());

  // Check output type and number of elements.
  Ort::Value& ort_output = ort_outputs[0];
  auto output_type_shape = ort_output.GetTensorTypeAndShapeInfo();
  size_t num_output_elems = output_type_shape.GetElementCount();

  ASSERT_EQ(output_type_shape.GetElementType(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
  // Compare against a size_t so the assertion does not mix signed and unsigned operands.
  ASSERT_EQ(num_output_elems, size_t{24});

  // Return output data.
  const float* output_values = ort_output.GetTensorData<float>();
  output_data.assign(output_values, output_values + num_output_elems);
}

// Test serializing an OrtGraph (MNIST) to GraphProto. Saves initializers to external file.
// Checks that the outputs of the serialized and original models are identical.
TEST(EpGraphTest, SerializeToProto_Mnist) {
Expand Down Expand Up @@ -436,6 +469,65 @@ TEST(EpGraphTest, SerializeToProto_ExternalInitializersInMemory) {
}
}

// Test serializing an OrtGraph that contains a 'TENSOR' node attribute (the ConstantOfShape test model,
// tensor_attribute.onnx) to a ModelProto. Initializers larger than 127 bytes are written to an external file.
// Checks that the outputs of the serialized and original models are identical.
TEST(EpGraphTest, SerializeToProto_ConstantOfShape) {
  const ORTCHAR_T* original_model_path = ORT_TSTR("testdata/ort_minimal_test_models/tensor_attribute.onnx");
  const ORTCHAR_T* serialized_model_path = ORT_TSTR("constant_of_shape.onnx");
  std::filesystem::remove(serialized_model_path);

  {
    auto test_graph = TestGraph::Load(original_model_path);
    ASSERT_NE(test_graph, nullptr) << "Failed to load test model";

    // Serialize OrtGraph to ModelProto. Save initializers to external file.
    std::string ext_ini_file_path = "constant_of_shape_serialized.bin";
    std::filesystem::remove(ext_ini_file_path);
    std::ofstream ext_ini_ofs(ext_ini_file_path, std::ios::binary);
    auto handle_initializer_data = [&ext_ini_ofs, &ext_ini_file_path](const OrtValueInfo* value_info,
                                                                      const void* data, size_t bytes,
                                                                      bool& is_external, std::string& location,
                                                                      int64_t& offset) -> Ort::Status {
      // OrtValueInfo* could be used to query initializer's name, type, shape,
      // node consumers, etc.
      static_cast<void>(value_info);

      if (bytes <= 127) {
        is_external = false;  // Keep small initializers stored inside the TensorProto.
        return Ort::Status{nullptr};
      }

      // Append the initializer to the external file and report where it was written.
      offset = ext_ini_ofs.tellp();
      location = ext_ini_file_path;
      ext_ini_ofs.write(static_cast<const char*>(data), bytes);
      ext_ini_ofs.flush();
      is_external = true;  // True if is external initializer.

      return Ort::Status{nullptr};
    };

    ONNX_NAMESPACE::ModelProto model_proto;
    ASSERT_CXX_ORTSTATUS_OK(OrtEpUtils::OrtGraphToProto(test_graph->GetOrtGraph(), model_proto,
                                                        handle_initializer_data));

    std::ofstream ofs(serialized_model_path, std::ios::binary);
    // SerializeToOstream reports failure through its bool result; do not let a failed write go unnoticed.
    ASSERT_TRUE(model_proto.SerializeToOstream(&ofs)) << "Failed to write serialized model";
    ofs.flush();

    ASSERT_TRUE(std::filesystem::exists(serialized_model_path));
    ASSERT_TRUE(std::filesystem::exists(ext_ini_file_path));
  }

  // Compare output of the original and serialized models. Should be identical.
  std::vector<float> output_original;
  std::vector<float> output_serialized;

  RunConstantOfShapeModel(original_model_path, output_original);
  RunConstantOfShapeModel(serialized_model_path, output_serialized);

  EXPECT_EQ(output_serialized, output_original);
}

static void Run3LayerModel(const ORTCHAR_T* model_path, bool input_cond, std::vector<float>& output_data) {
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
Ort::SessionOptions sess_options;
Expand Down Expand Up @@ -978,6 +1070,10 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_
ASSERT_EQ(api_node_attr_type, OrtOpAttrType::ORT_OP_ATTR_GRAPH);
break;
}
case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_TENSOR: {
ASSERT_EQ(api_node_attr_type, OrtOpAttrType::ORT_OP_ATTR_TENSOR);
break;
}
default:
// The unsupported type should be skipped by 'continue' above. It's unexpected so we force test to fail.
ASSERT_ORTSTATUS_OK(ort_api.CreateStatus(ORT_FAIL, "The attribute type is not in AttributeProto_AttributeType and this case shouldn't be hit."));
Expand Down
Loading