Skip to content
17 changes: 7 additions & 10 deletions include/onnxruntime/core/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -1220,7 +1220,10 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
#endif

#if !defined(ORT_MINIMAL_BUILD)
/** Gets the GraphProto representation of this Graph only. */
/** Gets the GraphProto representation of this Graph only.
* This does not remove in-memory tags for graph initializers.
* Use ToGraphProto() const to get a GraphProto that can be serialized externally.
*/
const ONNX_NAMESPACE::GraphProto& ToGraphProto();

/// <summary>
Expand Down Expand Up @@ -1595,20 +1598,14 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
/// This function is used by ToGraphProto() to ensure in-memory external data references
/// don't leak externally since they are non-standard.
///
/// It handles two scenarios:
/// - When GraphSynchronizationNeeded() is false: GraphProto is simply copied
/// It is used when GraphSynchronizationNeeded() is false: GraphProto is simply copied
/// from graph_proto_ by ToGraphProto(). This copy includes both main graph
/// and subgraph initializers. This function examines all initializers
/// and inlines any in-memory data references.
/// - When GraphSynchronizationNeeded() is true: ToGraphProto() generates a new GraphProto
/// using ToGraphProtoInternal(). This doesn't transfer main graph initializers, which are
/// copied and inlined by ToGraphProto() itself. This function processes only the subgraph initializers
/// as needed.
/// </summary>
/// <param name="output_graph_proto">The GraphProto to process</param>
/// <param name="process_main">Whether to process the main graph initializers</param>
/// <returns>Status indicating success or failure</returns> ///
Status ProcessSubgraphsInMemoryData(ONNX_NAMESPACE::GraphProto& output_graph_proto, bool process_main) const;
/// <returns>Status indicating success or failure</returns>
Status ProcessSubgraphsInMemoryData(ONNX_NAMESPACE::GraphProto& output_graph_proto) const;

/// <summary>
/// This function traverses the graph bottom up and externalizes
Expand Down
118 changes: 60 additions & 58 deletions onnxruntime/core/graph/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -666,12 +666,16 @@ void Node::ToProto(NodeProto& proto, bool update_subgraphs) const {

// Set attributes.
proto.clear_attribute();
for (const auto& attribute : attributes_) {
for (const auto& [name, attribute] : attributes_) {
const gsl::not_null<AttributeProto*> attr{proto.add_attribute()};
*attr = attribute.second; // copy
if (update_subgraphs && attr->has_g()) {
*attr = attribute; // copy
if (update_subgraphs && utils::HasGraph(*attr)) {
auto find_hit = attr_to_subgraph_map_.find(name);
// Force ToGraphProto() const to be called so
// that any in-memory TensorProto initializers go back to being inlined
const Graph& subgraph = *find_hit->second;
attr->clear_g();
*attr->mutable_g() = attr_to_subgraph_map_.find(attribute.first)->second->ToGraphProto();
*attr->mutable_g() = subgraph.ToGraphProto();
}
}

Expand Down Expand Up @@ -4311,8 +4315,7 @@ Status InlineOrCopyInitializer(const Graph& src_graph, const ONNX_NAMESPACE::Ten

} // namespace

Status Graph::ProcessSubgraphsInMemoryData(ONNX_NAMESPACE::GraphProto& output_graph_proto,
bool process_main) const {
Status Graph::ProcessSubgraphsInMemoryData(ONNX_NAMESPACE::GraphProto& output_graph_proto) const {
for (const auto& node : Nodes()) {
if (node.ContainsSubgraph()) {
// Let's find this node in the output_graph_proto
Expand Down Expand Up @@ -4343,99 +4346,98 @@ Status Graph::ProcessSubgraphsInMemoryData(ONNX_NAMESPACE::GraphProto& output_gr
"Subgraph ", name, " is referred to in GetAttributeNameToSubgraphMap, but not found in node ",
node.Name(), " while attempting to recurse into it.");
auto& result_subgraph = *sub_hit->mutable_g();
ORT_RETURN_IF_ERROR(subgraph->ProcessSubgraphsInMemoryData(result_subgraph, process_main));
ORT_RETURN_IF_ERROR(subgraph->ProcessSubgraphsInMemoryData(result_subgraph));
}
}
}

// When graph_proto is copied from graph_proto, initializers already present in the main graph
if (parent_graph_ != nullptr || process_main) {
#if !defined(DISABLE_SPARSE_TENSORS)
auto* mutable_initializers = output_graph_proto.mutable_initializer();
const auto& model_path = ModelPath();
const bool has_sparse_initializers = !sparse_tensor_names_.empty();
const auto sparse_end = sparse_tensor_names_.end();
auto* mutable_initializers = output_graph_proto.mutable_initializer();
const auto& model_path = ModelPath();
const bool has_sparse_initializers = !sparse_tensor_names_.empty();
const auto sparse_end = sparse_tensor_names_.end();

// We want to make sure that sparse initializers do not appear
// as dense duplicates within the initializers list.
std::optional<InlinedHashSet<std::string>> initializer_to_remove;
if (has_sparse_initializers) {
// We need to remove the dense initializers that are sparse tensors
initializer_to_remove.emplace();
}
// We want to make sure that sparse initializers do not appear
// as dense duplicates within the initializers list.
std::optional<InlinedHashSet<std::string>> initializer_to_remove;
if (has_sparse_initializers) {
// We need to remove the dense initializers that are sparse tensors
initializer_to_remove.emplace();
}

for (auto first = mutable_initializers->begin(), end = mutable_initializers->end(); first != end; ++first) {
auto& initializer = *first;
if (utils::HasExternalDataInMemory(initializer)) {
// If the initializer has external data in memory, we need to inline it.
ORT_RETURN_IF_ERROR(InlineOrCopyInitializer(*this, initializer, initializer));
}
if (has_sparse_initializers && sparse_end != sparse_tensor_names_.find(initializer.name())) {
auto& sparse_initializer = *output_graph_proto.add_sparse_initializer();
ORT_RETURN_IF_ERROR(utils::DenseTensorToSparseTensorProto(initializer, model_path, sparse_initializer));
initializer_to_remove->insert(initializer.name());
}
for (auto first = mutable_initializers->begin(), end = mutable_initializers->end(); first != end; ++first) {
auto& initializer = *first;
if (utils::HasExternalDataInMemory(initializer)) {
// If the initializer has external data in memory, we need to inline it.
ORT_RETURN_IF_ERROR(InlineOrCopyInitializer(*this, initializer, initializer));
}

// erase/remove dense initializers that are sparse tensors so no duplicates are present
if (initializer_to_remove && !initializer_to_remove->empty()) {
mutable_initializers->erase(std::remove_if(
mutable_initializers->begin(), mutable_initializers->end(),
[&initializer_to_remove](const ONNX_NAMESPACE::TensorProto& initializer) {
return initializer_to_remove->count(initializer.name()) > 0;
}),
mutable_initializers->end());
if (has_sparse_initializers && sparse_end != sparse_tensor_names_.find(initializer.name())) {
auto& sparse_initializer = *output_graph_proto.add_sparse_initializer();
ORT_RETURN_IF_ERROR(utils::DenseTensorToSparseTensorProto(initializer, model_path, sparse_initializer));
initializer_to_remove->insert(initializer.name());
}
}

// erase/remove dense initializers that are sparse tensors so no duplicates are present
if (initializer_to_remove && !initializer_to_remove->empty()) {
mutable_initializers->erase(std::remove_if(
mutable_initializers->begin(), mutable_initializers->end(),
[&initializer_to_remove](const ONNX_NAMESPACE::TensorProto& initializer) {
return initializer_to_remove->count(initializer.name()) > 0;
}),
mutable_initializers->end());
}
#else
for (auto& initializer : *output_graph_proto.mutable_initializer()) {
if (utils::HasExternalDataInMemory(initializer)) {
// If the initializer has external data in memory, we need to inline it.
ORT_RETURN_IF_ERROR(InlineOrCopyInitializer(*this, initializer, initializer));
}
for (auto& initializer : *output_graph_proto.mutable_initializer()) {
if (utils::HasExternalDataInMemory(initializer)) {
// If the initializer has external data in memory, we need to inline it.
ORT_RETURN_IF_ERROR(InlineOrCopyInitializer(*this, initializer, initializer));
}
#endif
}
#endif

return Status::OK();
}

ONNX_NAMESPACE::GraphProto Graph::ToGraphProto() const {
GraphProto result;
if (!GraphProtoSyncNeeded()) {
result = *graph_proto_;
ORT_THROW_IF_ERROR(ProcessSubgraphsInMemoryData(result, /*process_main*/ true));
ORT_THROW_IF_ERROR(ProcessSubgraphsInMemoryData(result));
} else {
// Recursion is handled via Node::ToProto() const -> Graph::ToGraphProto() const (this method)
// so below we handle this graph only.
ToGraphProtoInternal(result);

ORT_THROW_IF_ERROR(ProcessSubgraphsInMemoryData(result, /*process_main*/ false));

// Add initializers to parent graph by copy converting them from graph_proto_
// ToGraphProtoInternal() does not copy initializers for the main graph
// Add initializers to main graph by copying them from graph_proto_
// ToGraphProtoInternal() does not copy initializers for the graph that it was invoked for.
auto* mutable_initializers = result.mutable_initializer();

#if !defined(DISABLE_SPARSE_TENSORS)
const auto& model_path = ModelPath();
const bool has_sparse_initializers = !sparse_tensor_names_.empty();
const auto sparse_end = sparse_tensor_names_.end();

for (const auto& initializer : graph_proto_->initializer()) {
if (!has_sparse_initializers || sparse_end == sparse_tensor_names_.find(initializer.name())) {
for (const auto& [name, tensor_proto] : name_to_initial_tensor_) {
const auto& initializer = *tensor_proto;
if (!has_sparse_initializers || sparse_end == sparse_tensor_names_.find(name)) {
ORT_THROW_IF_ERROR(InlineOrCopyInitializer(*this, initializer,
*mutable_initializers->Add()));
} else {
auto& sparse_initializer = *result.add_sparse_initializer();
if (utils::HasExternalDataInMemory(initializer)) {
ONNX_NAMESPACE::TensorProto tensor_proto;
ONNX_NAMESPACE::TensorProto tensor_proto_inlined;
ORT_THROW_IF_ERROR(InlineOrCopyInitializer(*this, initializer,
tensor_proto));
ORT_THROW_IF_ERROR(utils::DenseTensorToSparseTensorProto(tensor_proto, model_path, sparse_initializer));
tensor_proto_inlined));
ORT_THROW_IF_ERROR(utils::DenseTensorToSparseTensorProto(tensor_proto_inlined, model_path, sparse_initializer));
} else {
ORT_THROW_IF_ERROR(utils::DenseTensorToSparseTensorProto(initializer, model_path, sparse_initializer));
}
}
}
#else
for (const auto& initializer : graph_proto_->initializer()) {
ORT_THROW_IF_ERROR(InlineOrCopyInitializer(*this, initializer, *mutable_initializers->Add()));
for (const auto& [name, tensor_proto] : name_to_initial_tensor_) {
ORT_THROW_IF_ERROR(InlineOrCopyInitializer(*this, *tensor_proto, *mutable_initializers->Add()));
}
#endif
}
Expand Down
Loading