Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -1252,7 +1252,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
onnx_test_runner_common onnxruntime_test_utils onnxruntime_common
onnxruntime onnxruntime_flatbuffers onnx_test_data_proto
${onnxruntime_EXTERNAL_LIBRARIES}
${GETOPT_LIB_WIDE} ${SYS_PATH_LIB} ${CMAKE_DL_LIBS})
absl::flags absl::flags_parse ${SYS_PATH_LIB} ${CMAKE_DL_LIBS})
if(NOT WIN32)
if(onnxruntime_USE_SNPE)
list(APPEND onnxruntime_perf_test_libs onnxruntime_providers_snpe)
Expand All @@ -1272,7 +1272,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
target_link_libraries(onnxruntime_perf_test PRIVATE debug dbghelp advapi32)
endif()
else()
target_link_libraries(onnxruntime_perf_test PRIVATE onnx_test_runner_common ${GETOPT_LIB_WIDE} ${onnx_test_libs})
target_link_libraries(onnxruntime_perf_test PRIVATE onnx_test_runner_common absl::flags absl::flags_parse ${onnx_test_libs})
endif()
set_target_properties(onnxruntime_perf_test PROPERTIES FOLDER "ONNXRuntimeTest")

Expand Down
7 changes: 6 additions & 1 deletion include/onnxruntime/core/framework/execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,12 @@ class IExecutionProvider {
/**
Get the device id of current execution provider
*/
virtual int GetDeviceId() const { return default_device_.Id(); };
virtual int GetDeviceId() const { return default_device_.Id(); }

/**
   Get the OrtDevice the execution provider was registered with.
   Returns a reference to default_device_, the same member GetDeviceId() reads the id from.
*/
const OrtDevice& GetDevice() const {
  return default_device_;
}

/**
Get execution provider's configuration options.
Expand Down
7 changes: 7 additions & 0 deletions include/onnxruntime/core/framework/ortdevice.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,13 @@ struct OrtDevice {
return alignment < other.alignment;
}

// Returns true when this device and `other` have the same device type, memory type,
// vendor id and device id. The alignment member is deliberately not compared.
bool EqualIgnoringAlignment(const OrtDevice& other) const {
  if (device_type != other.device_type) {
    return false;
  }
  if (memory_type != other.memory_type) {
    return false;
  }
  if (vendor_id != other.vendor_id) {
    return false;
  }
  return device_id == other.device_id;
}

private:
// Device type.
int32_t device_type : 8;
Expand Down
49 changes: 39 additions & 10 deletions include/onnxruntime/core/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -1220,7 +1220,10 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
#endif

#if !defined(ORT_MINIMAL_BUILD)
/** Gets the GraphProto representation of this Graph only. */
/** Gets the GraphProto representation of this Graph only.
* This does not remove in-memory tags for graph initializers.
* Use ToGraphProto() const to get a GraphProto that can be serialized externally.
*/
const ONNX_NAMESPACE::GraphProto& ToGraphProto();

/// <summary>
Expand Down Expand Up @@ -1439,6 +1442,27 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
return Resolve(default_options);
}

/// <summary>
/// This function converts all the graph TensorProto initializers into OrtValues
/// and creates an in-memory external data reference for each OrtValue.
/// </summary>
/// <returns>Status indicating success or failure.</returns>
Status ConvertInitializersIntoOrtValues();

/**
 * @brief Regenerates a subset of this graph's initializers into the supplied GraphProto.
 *
 * For every entry referenced by `iterators`, in the order given, the initializer is written to
 * `output_graph_proto` through InlineOrCopyInitializer() (expected to inline in-memory external
 * data so such references do not end up in the output; confirm against that helper).
 * When sparse tensor support is compiled in, any sparse initializers already present in
 * `output_graph_proto` are cleared first, and initializers whose names are listed in
 * sparse_tensor_names_ are emitted as sparse initializers instead of dense ones.
 * Dense initializers already present in `output_graph_proto` are not cleared by this function;
 * new entries are appended, so callers that want a replacement clear them beforehand.
 *
 * @param iterators Span of iterators into the initializer map, listed in the order in which they should be processed
 * @param output_graph_proto The GraphProto that receives the regenerated initializers
 * @return Status Returns a Status object indicating success or the first error that occurred
 */
Status RegenerateInitializersAndReplaceInMemory(gsl::span<const InitializedTensorSet::const_iterator> iterators,
ONNX_NAMESPACE::GraphProto& output_graph_proto) const;

/// Returns a read-only reference to outer_scope_node_arg_names_.
const std::unordered_set<std::string>& GetOuterScopeNodeArgNames() const noexcept {
  const auto& outer_scope_names = outer_scope_node_arg_names_;
  return outer_scope_names;
}
Expand Down Expand Up @@ -1595,20 +1619,25 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
/// This function is used by ToGraphProto() to ensure in-memory external data references
/// don't leak externally since they are non-standard.
///
/// It handles two scenarios:
/// - When GraphSynchronizationNeeded() is false: GraphProto is simply copied
/// It is used when GraphSynchronizationNeeded() is false: GraphProto is simply copied
/// from graph_proto_ by ToGraphProto(). This copy includes both main graph
/// and subgraph initializers. This function examines all initializers
/// and inlines any in-memory data references.
/// - When GraphSynchronizationNeeded() is true: ToGraphProto() generates a new GraphProto
/// using ToGraphProtoInternal(). This doesn't transfer main graph initializers, which are
/// copied and inlined by ToGraphProto() itself. This function processes only the subgraph initializers
/// as needed.
/// </summary>
/// <param name="output_graph_proto">The GraphProto to process</param>
/// <param name="process_main">Whether to process the main graph initializers</param>
/// <returns>Status indicating success or failure</returns> ///
Status ProcessSubgraphsInMemoryData(ONNX_NAMESPACE::GraphProto& output_graph_proto, bool process_main) const;
/// <returns>Status indicating success or failure</returns>
Status ProcessSubgraphsInMemoryData(ONNX_NAMESPACE::GraphProto& output_graph_proto) const;

/// <summary>
/// This function replaces all of the initializers within output_graph_proto
/// with those of this Graph instance. All in-memory initializers are regenerated and inlined.
/// This is necessary even if graph_proto_ is already up to date, because initializers() may
/// contain obsolete initializers that are no longer in use due to optimizations and that hold obsolete
/// references to OrtValues that may no longer be around (since we like appending rather than replacing).
/// </summary>
/// <param name="output_graph_proto">Destination GraphProto to receive the updated initializers.</param>
/// <returns>Status indicating success or failure.</returns>
Status RegenerateInitializersAndReplaceInMemory(ONNX_NAMESPACE::GraphProto& output_graph_proto) const;

/// <summary>
/// This function traverses the graph bottom up and externalizes
Expand Down
11 changes: 10 additions & 1 deletion include/onnxruntime/core/session/environment.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,15 @@ class Environment {
return shared_allocators_;
}

/**
* Returns an AllocatorPtr for a shared IAllocator based allocator if it matches the memory info.
* The OrtMemoryInfo name and whether it's an arena or device allocator is ignored in the lookup, as is the
* alignment.
* The user calling this function is not expected to know the alignment, and we expect the allocator instance to be
* created with a valid alignment for the device.
*/
AllocatorPtr GetRegisteredSharedAllocator(const OrtMemoryInfo& mem_info) const;

/**
* Removes registered allocator that was previously registered for sharing between multiple sessions.
*/
Expand Down Expand Up @@ -171,7 +180,7 @@ class Environment {
std::unique_ptr<onnxruntime::concurrency::ThreadPool> inter_op_thread_pool_;
bool create_global_thread_pools_{false};

std::mutex mutex_;
mutable std::mutex mutex_;

// shared allocators from various sources.
// CreateAndRegisterAllocator[V2]: IAllocator allocators created by ORT
Expand Down
182 changes: 78 additions & 104 deletions onnxruntime/core/graph/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -666,12 +666,16 @@ void Node::ToProto(NodeProto& proto, bool update_subgraphs) const {

// Set attributes.
proto.clear_attribute();
for (const auto& attribute : attributes_) {
for (const auto& [name, attribute] : attributes_) {
const gsl::not_null<AttributeProto*> attr{proto.add_attribute()};
*attr = attribute.second; // copy
if (update_subgraphs && attr->has_g()) {
*attr = attribute; // copy
if (update_subgraphs && utils::HasGraph(*attr)) {
auto find_hit = attr_to_subgraph_map_.find(name);
// Force ToGraphProto() const to be called so
// that any in-memory TensorProto initializers go back to being inlined
const Graph& subgraph = *find_hit->second;
attr->clear_g();
*attr->mutable_g() = attr_to_subgraph_map_.find(attribute.first)->second->ToGraphProto();
*attr->mutable_g() = subgraph.ToGraphProto();
}
}

Expand Down Expand Up @@ -3381,7 +3385,12 @@ Status Graph::Resolve(const ResolveOptions& options) {

return Status::OK(); };

ORT_RETURN_IF_ERROR(ForThisAndAllSubgraphs(all_subgraphs, finalize_func));
return ForThisAndAllSubgraphs(all_subgraphs, finalize_func);
}

Status Graph::ConvertInitializersIntoOrtValues() {
std::vector<Graph*> all_subgraphs;
FindAllSubgraphs(all_subgraphs);

auto put_weights_maybe_in_memory_func = [&](Graph& graph) -> Status {
// if we have any initializers that are not in memory, put them there.
Expand Down Expand Up @@ -4308,11 +4317,47 @@ Status InlineOrCopyInitializer(const Graph& src_graph, const ONNX_NAMESPACE::Ten
}
return Status::OK();
}

} // namespace

Status Graph::ProcessSubgraphsInMemoryData(ONNX_NAMESPACE::GraphProto& output_graph_proto,
bool process_main) const {
// Writes the initializers referenced by `iterators` into `output_graph_proto`, in the given order.
//
// Dense initializers are appended to output_graph_proto.initializer() through
// InlineOrCopyInitializer(). Existing dense entries are NOT cleared here; callers do that
// when they want a replacement.
//
// When sparse tensor support is compiled in, the sparse initializer list of the output is
// cleared first, and every initializer whose name is listed in sparse_tensor_names_ is
// emitted as a sparse initializer instead. If such an initializer keeps its data in memory
// (utils::HasExternalDataInMemory), it is first passed through InlineOrCopyInitializer()
// into a temporary TensorProto and the temporary is converted to the sparse form.
//
// Returns the first failing Status, or Status::OK() once all entries are processed.
Status Graph::RegenerateInitializersAndReplaceInMemory(gsl::span<const InitializedTensorSet::const_iterator> iterators,
                                                       ONNX_NAMESPACE::GraphProto& output_graph_proto) const {
  auto* dense_out = output_graph_proto.mutable_initializer();

#if !defined(DISABLE_SPARSE_TENSORS)
  output_graph_proto.clear_sparse_initializer();

  const auto& model_path = ModelPath();
  const bool any_sparse = !sparse_tensor_names_.empty();

  for (const auto& entry_it : iterators) {
    const auto& init_name = entry_it->first;
    const auto& source_proto = *entry_it->second;

    // The set lookup is skipped entirely when the graph has no sparse initializers.
    const bool emit_sparse = any_sparse &&
                             sparse_tensor_names_.find(init_name) != sparse_tensor_names_.end();
    if (!emit_sparse) {
      ORT_RETURN_IF_ERROR(InlineOrCopyInitializer(*this, source_proto, *dense_out->Add()));
      continue;
    }

    // The sparse slot is added before any conversion is attempted, as in the dense path.
    auto& sparse_out = *output_graph_proto.add_sparse_initializer();
    if (!utils::HasExternalDataInMemory(source_proto)) {
      ORT_RETURN_IF_ERROR(utils::DenseTensorToSparseTensorProto(source_proto, model_path, sparse_out));
      continue;
    }

    // In-memory data is passed through InlineOrCopyInitializer() before the sparse conversion.
    ONNX_NAMESPACE::TensorProto inlined_dense;
    ORT_RETURN_IF_ERROR(InlineOrCopyInitializer(*this, source_proto, inlined_dense));
    ORT_RETURN_IF_ERROR(utils::DenseTensorToSparseTensorProto(inlined_dense, model_path, sparse_out));
  }
#else
  for (const auto& entry_it : iterators) {
    ORT_RETURN_IF_ERROR(InlineOrCopyInitializer(*this, *entry_it->second, *dense_out->Add()));
  }
#endif
  return Status::OK();
}

Status Graph::ProcessSubgraphsInMemoryData(ONNX_NAMESPACE::GraphProto& output_graph_proto) const {
for (const auto& node : Nodes()) {
if (node.ContainsSubgraph()) {
// Let's find this node in the output_graph_proto
Expand Down Expand Up @@ -4343,103 +4388,48 @@ Status Graph::ProcessSubgraphsInMemoryData(ONNX_NAMESPACE::GraphProto& output_gr
"Subgraph ", name, " is referred to in GetAttributeNameToSubgraphMap, but not found in node ",
node.Name(), " while attempting to recurse into it.");
auto& result_subgraph = *sub_hit->mutable_g();
ORT_RETURN_IF_ERROR(subgraph->ProcessSubgraphsInMemoryData(result_subgraph, process_main));
ORT_RETURN_IF_ERROR(subgraph->ProcessSubgraphsInMemoryData(result_subgraph));
}
}
}

// When graph_proto is copied from graph_proto, initializers already present in the main graph
if (parent_graph_ != nullptr || process_main) {
#if !defined(DISABLE_SPARSE_TENSORS)
auto* mutable_initializers = output_graph_proto.mutable_initializer();
const auto& model_path = ModelPath();
const bool has_sparse_initializers = !sparse_tensor_names_.empty();
const auto sparse_end = sparse_tensor_names_.end();

// We want to make sure that sparse initializers do not appear
// as dense duplicates within the initializers list.
std::optional<InlinedHashSet<std::string>> initializer_to_remove;
if (has_sparse_initializers) {
// We need to remove the dense initializers that are sparse tensors
initializer_to_remove.emplace();
}

for (auto first = mutable_initializers->begin(), end = mutable_initializers->end(); first != end; ++first) {
auto& initializer = *first;
if (utils::HasExternalDataInMemory(initializer)) {
// If the initializer has external data in memory, we need to inline it.
ORT_RETURN_IF_ERROR(InlineOrCopyInitializer(*this, initializer, initializer));
}
if (has_sparse_initializers && sparse_end != sparse_tensor_names_.find(initializer.name())) {
auto& sparse_initializer = *output_graph_proto.add_sparse_initializer();
ORT_RETURN_IF_ERROR(utils::DenseTensorToSparseTensorProto(initializer, model_path, sparse_initializer));
initializer_to_remove->insert(initializer.name());
}
}

// erase/remove dense initializers that are sparse tensors so no duplicates are present
if (initializer_to_remove && !initializer_to_remove->empty()) {
mutable_initializers->erase(std::remove_if(
mutable_initializers->begin(), mutable_initializers->end(),
[&initializer_to_remove](const ONNX_NAMESPACE::TensorProto& initializer) {
return initializer_to_remove->count(initializer.name()) > 0;
}),
mutable_initializers->end());
}
#else
for (auto& initializer : *output_graph_proto.mutable_initializer()) {
if (utils::HasExternalDataInMemory(initializer)) {
// If the initializer has external data in memory, we need to inline it.
ORT_RETURN_IF_ERROR(InlineOrCopyInitializer(*this, initializer, initializer));
}
// Filter in iterators for weights that are present in the name_to_initial_tensor_ map
// and preserve the order. This is needed for tests.
InlinedVector<InitializedTensorSet::const_iterator> initializers_to_process;
initializers_to_process.reserve(name_to_initial_tensor_.size());
for (const auto& tensor_proto : output_graph_proto.initializer()) {
auto hit = name_to_initial_tensor_.find(tensor_proto.name());
if (hit != name_to_initial_tensor_.end()) {
initializers_to_process.push_back(hit);
}
#endif
}
return Status::OK();

output_graph_proto.clear_initializer();
return RegenerateInitializersAndReplaceInMemory(initializers_to_process, output_graph_proto);
}

ONNX_NAMESPACE::GraphProto Graph::ToGraphProto() const {
GraphProto result;
if (!GraphProtoSyncNeeded()) {
result = *graph_proto_;
ORT_THROW_IF_ERROR(ProcessSubgraphsInMemoryData(result, /*process_main*/ true));
ORT_THROW_IF_ERROR(ProcessSubgraphsInMemoryData(result));
} else {
// Recursion is handled via Node::ToProto() const -> Graph::ToGraphProto() const (this method)
// so below we handle this graph only.
ToGraphProtoInternal(result);

ORT_THROW_IF_ERROR(ProcessSubgraphsInMemoryData(result, /*process_main*/ false));

// Add initializers to parent graph by copy converting them from graph_proto_
// ToGraphProtoInternal() does not copy initializers for the main graph
auto* mutable_initializers = result.mutable_initializer();

#if !defined(DISABLE_SPARSE_TENSORS)
const auto& model_path = ModelPath();
const bool has_sparse_initializers = !sparse_tensor_names_.empty();
const auto sparse_end = sparse_tensor_names_.end();

for (const auto& initializer : graph_proto_->initializer()) {
if (!has_sparse_initializers || sparse_end == sparse_tensor_names_.find(initializer.name())) {
ORT_THROW_IF_ERROR(InlineOrCopyInitializer(*this, initializer,
*mutable_initializers->Add()));
} else {
auto& sparse_initializer = *result.add_sparse_initializer();
if (utils::HasExternalDataInMemory(initializer)) {
ONNX_NAMESPACE::TensorProto tensor_proto;
ORT_THROW_IF_ERROR(InlineOrCopyInitializer(*this, initializer,
tensor_proto));
ORT_THROW_IF_ERROR(utils::DenseTensorToSparseTensorProto(tensor_proto, model_path, sparse_initializer));
} else {
ORT_THROW_IF_ERROR(utils::DenseTensorToSparseTensorProto(initializer, model_path, sparse_initializer));
}
InlinedVector<InitializedTensorSet::const_iterator> initializers_to_process;
initializers_to_process.reserve(name_to_initial_tensor_.size());
for (const auto& tensor_proto : graph_proto_->initializer()) {
auto hit = name_to_initial_tensor_.find(tensor_proto.name());
if (hit != name_to_initial_tensor_.end()) {
initializers_to_process.push_back(hit);
}
}
#else
for (const auto& initializer : graph_proto_->initializer()) {
ORT_THROW_IF_ERROR(InlineOrCopyInitializer(*this, initializer, *mutable_initializers->Add()));
}
#endif
}

ORT_THROW_IF_ERROR(RegenerateInitializersAndReplaceInMemory(initializers_to_process,
result));
}
return result;
}

Expand Down Expand Up @@ -5235,23 +5225,7 @@ Status Graph::AddConstantProtoAsInitializer(const ONNX_NAMESPACE::NodeProto& nod
tensor_proto.set_name(std::string(new_name.value()));
}

// In the constant node, we won't have symbolic dims.
const auto tensor_shape = utils::GetTensorShapeFromTensorProto(tensor_proto);
auto ml_data = DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto.data_type())->GetElementType();
const size_t size_in_bytes = Tensor::CalculateTensorStorageSize(ml_data, tensor_shape);

if (size_in_bytes > utils::kSmallTensorExternalDataThreshold) {
OrtValue ort_value;
ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), ModelPath(), tensor_proto,
CPUAllocator::DefaultInstance(), ort_value));

constexpr const bool use_tensor_buffer_true = true;
auto tensor_proto_to_add = utils::TensorToTensorProto(ort_value.Get<Tensor>(), tensor_proto.name(),
use_tensor_buffer_true);
ORT_RETURN_IF_ERROR(AddInitializedOrtValue(tensor_proto_to_add, ort_value));
} else {
AddInitializedTensor(tensor_proto);
}
AddInitializedTensor(tensor_proto);

if (GetNodeArg(tensor_proto.name()) == nullptr) {
TypeProto t{utils::TypeProtoFromTensorProto(tensor_proto)};
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/optimizer/attention_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ static NodeArg& MergeQkvWeights(Graph& graph, int64_t hidden_size,
utils::SetRawDataInTensorProto(initializer, result.data(), gsl::narrow<size_t>(element_count) * sizeof(MLFloat16));
}

return graph_utils::AddInitializerWithExternalData(graph, initializer);
return graph_utils::AddInitializer(graph, initializer);
}

static NodeArg* ConvertMaskToInt32(Graph& graph, NodeArg* mask_input, ProviderType provider_type,
Expand Down
Loading
Loading