Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions include/onnxruntime/core/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -1454,13 +1454,6 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
return Resolve(default_options);
}

/// <summary>
/// This function converts all the graph TensorProto initializers into OrtValues
/// and creates a in-memory external data reference for each OrtValue.
/// </summary>
/// <returns></returns>
Status ConvertInitializersIntoOrtValues();

/**
* @brief Converts a subset of graph TensorProto initializers into OrtValues and updates the graph proto.
*
Expand Down
33 changes: 20 additions & 13 deletions onnxruntime/core/framework/session_state_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,12 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st
std::move(tensor), ort_value);
}
} else {
// for internal initializer, always allocate memory on device - tensor
ORT_RETURN_IF_ERROR(AllocateTensor(memory_buffer, tensor, type, tensor_shape,
use_device_allocator_for_initializers, alloc));

if (device == default_cpu_device) {
// deserialize directly to CPU tensor
// Do not use arena for internal initializer, just like we do for OrtValue initializers
ORT_RETURN_IF_ERROR(AllocateTensorOnDeviceOrMemory(/* use_device_allocator_for_initializers =*/true,
tensor_shape, type,
default_cpu_alloc, tensor));
ORT_RETURN_IF_ERROR(utils::TensorProtoToTensor(env, proto_path.c_str(), tensor_proto, tensor));
Tensor::InitOrtValue(std::move(tensor), ort_value);
return common::Status::OK();
Expand All @@ -154,13 +154,19 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "string tensor is not supported for copying between allocators");
}

// Allocate according to the plan on the device or directly on the device according to
// use_device_allocator_for_initializers
ORT_RETURN_IF_ERROR(AllocateTensor(memory_buffer, tensor, type, tensor_shape,
use_device_allocator_for_initializers, alloc));

// deserialize to CPU first for non-CPU allocator, then copy
// for internal initializer
// 1. allocate memory on CPU - deserialized_tensor
// 2. deserialize tensor_proto into a preallocated tensor (deserialized_tensor)
// 1. allocate memory on CPU - deserialized_tensor. Do not use arena not to waste space for temporary buffers.
// 2. deserialize tensor_proto into a pre-allocated tensor (deserialized_tensor)
// 3. copy tensor from CPU to device - deserialized_tensor -> tensor (allocated above) -> ort_value
Tensor deserialized_tensor;
ORT_RETURN_IF_ERROR(AllocateTensorOnDeviceOrMemory(use_device_allocator_for_initializers, tensor_shape, type,
ORT_RETURN_IF_ERROR(AllocateTensorOnDeviceOrMemory(/* use_device_allocator_for_initializers =*/true,
tensor_shape, type,
default_cpu_alloc, deserialized_tensor));

ORT_RETURN_IF_ERROR(utils::TensorProtoToTensor(env, proto_path.c_str(), tensor_proto, deserialized_tensor));
Expand Down Expand Up @@ -346,6 +352,13 @@ common::Status SaveInitializedTensors(
<< i.second << " bytes for " << i.first.ToString() << std::endl;
}

// ??? Should we ignore this session option if the EP is explicitly providing the read only allocator?
// bool have_readonly_initializer_allocator = alloc->Info().alloc_type == OrtReadOnlyAllocator;
// This option also means to ignore arena if present and use Reserve().
const bool use_device_allocator_for_initializers =
session_options.config_options.GetConfigOrDefault(
kOrtSessionOptionsUseDeviceAllocatorForInitializers, "0") == "1";

// 3. create weight tensors based on weights buffer
for (const auto& entry : id_to_initialized_tensor) {
// We check for cancellation for every initializer since mapping from disk can be costly
Expand Down Expand Up @@ -375,12 +388,6 @@ common::Status SaveInitializedTensors(
// TODO: if the tensor need be copied, does it have enough room?
ORT_RETURN_IF_ERROR(planner.GetPreallocatedBuffer(ort_value_index, name, memory_buffer, alloc));

// ??? Should we ignore this session option if the EP is explicitly providing the read only allocator?
// bool have_readonly_initializer_allocator = alloc->Info().alloc_type == OrtReadOnlyAllocator;
const bool use_device_allocator_for_initializers =
session_options.config_options.GetConfigOrDefault(
kOrtSessionOptionsUseDeviceAllocatorForInitializers, "0") == "1";

// Check if we already have an OrtValue for this initializer on CPU
if (OrtValue ort_value_from_graph;
graph.GetOrtValueInitializer(name, ort_value_from_graph)) {
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/framework/tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,14 @@ Tensor::Tensor(MLDataType elt_type, const TensorShape& shape, std::shared_ptr<IA
if (len > 0) {
p_data = allocator->Alloc(len);
}
Init(elt_type, shape, p_data, allocator, 0L);
Init(elt_type, shape, p_data, std::move(allocator), 0L);
}

Tensor::Tensor(MLDataType elt_type, const TensorShape& shape, void* p_data, std::shared_ptr<IAllocator> deleter,
ptrdiff_t offset, gsl::span<const int64_t> strides)
: alloc_info_(deleter->Info()) {
ORT_ENFORCE(elt_type != nullptr);
Init(elt_type, shape, p_data, deleter, offset, strides);
Init(elt_type, shape, p_data, std::move(deleter), offset, strides);
}

void Tensor::InitOrtValue(MLDataType elt_type, const TensorShape& shape, std::shared_ptr<IAllocator> allocator,
Expand Down
71 changes: 38 additions & 33 deletions onnxruntime/core/graph/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1231,6 +1231,26 @@ Graph::Graph(const Model& owning_model,
ArgNameToTypeMap name_to_type_map;
const auto& model_path = ModelPath();

// If the tensor proto data is large enough, externalize it and replace with a tensor_proto
// with external data reference pointing to an OrtValue, otherwise do nothing.
auto put_data_maybe_in_memory = [this, &model_path](ONNX_NAMESPACE::TensorProto& tensor_proto) {
size_t size_in_bytes = 0;
ORT_THROW_IF_ERROR(utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &size_in_bytes));
if (size_in_bytes > utils::kSmallTensorExternalDataThreshold) {
OrtValue ort_value;
ORT_THROW_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), model_path, tensor_proto,
CPUAllocator::DefaultInstance(), ort_value));
constexpr const bool use_tensor_buffer_true = true;
auto tensor_proto_to_add = utils::TensorToTensorProto(ort_value.Get<Tensor>(), tensor_proto.name(),
use_tensor_buffer_true);
assert(ort_value.IsAllocated());
auto ins_result = ortvalue_initializers_.insert_or_assign(tensor_proto_to_add.name(), std::move(ort_value));
ORT_ENFORCE(ins_result.second, "Unexpected duplicate insert or assign OrtValue for tensor: ", tensor_proto_to_add.name(),
" in the initializer list.");
tensor_proto = std::move(tensor_proto_to_add);
}
};

// Process 'Constant' nodes
// Put the 'TensorProto' stored in the 'Constant' nodes attribute into the graphs initializer list
for (auto& node : graph_proto_->node()) {
Expand All @@ -1250,6 +1270,8 @@ Graph::Graph(const Model& owning_model,
}
}

put_data_maybe_in_memory(*tensor);

// Ensure initializers are also graph inputs.
if (ir_version_ < 4) {
TypeProto t{utils::TypeProtoFromTensorProto(*tensor)};
Expand Down Expand Up @@ -1326,7 +1348,22 @@ Graph::Graph(const Model& owning_model,
}

// Copy initial tensors to a map.
for (auto& tensor : graph_proto_->initializer()) {
for (int i = 0, lim = graph_proto_->initializer_size(); i < lim; ++i) {
auto& tensor = *graph_proto_->mutable_initializer(i);
// If data is on disk, it will be loaded either by optimizers
// or during session state finalization.
// If data is already in memory, do nothing.
if (!utils::HasExternalData(tensor)) {
const bool is_sparse = sparse_tensor_names_.count(tensor.name());
if (is_sparse) {
sparse_tensor_names_.erase(tensor.name());
}
put_data_maybe_in_memory(tensor);
if (is_sparse) {
sparse_tensor_names_.emplace(tensor.name());
}
}

auto p = name_to_initial_tensor_.emplace(tensor.name(), &tensor);
if (!p.second) {
LOGS(logger_, WARNING) << "Duplicate initializer (dense, sparse or ConstantNode): '" << tensor.name()
Expand Down Expand Up @@ -3415,38 +3452,6 @@ Status Graph::Resolve(const ResolveOptions& options) {
return ForThisAndAllSubgraphs(all_subgraphs, finalize_func);
}

Status Graph::ConvertInitializersIntoOrtValues() {
std::vector<Graph*> all_subgraphs;
FindAllSubgraphs(all_subgraphs);

auto put_weights_maybe_in_memory_func = [&](Graph& graph) -> Status {
// if we have any initializers that are not in memory, put them there.
const auto& model_path = graph.ModelPath();
auto& graph_proto = *graph.graph_proto_;
for (int i = 0, lim = graph_proto.initializer_size(); i < lim; ++i) {
auto& tensor_proto = *graph_proto.mutable_initializer(i);
if (utils::HasExternalData(tensor_proto)) {
continue; // ignore data on disk, that will be loaded either by EP or at session_state finalize
}

size_t size_in_bytes = 0;
ORT_RETURN_IF_ERROR(utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &size_in_bytes));
if (size_in_bytes > utils::kSmallTensorExternalDataThreshold) {
OrtValue ort_value;
ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), model_path, tensor_proto,
CPUAllocator::DefaultInstance(), ort_value));
constexpr const bool use_tensor_buffer_true = true;
auto tensor_proto_to_add = utils::TensorToTensorProto(ort_value.Get<Tensor>(), tensor_proto.name(),
use_tensor_buffer_true);
ORT_RETURN_IF_ERROR(graph.ReplaceInitializedTensor(tensor_proto_to_add, ort_value));
}
}
return Status::OK();
};

return ForThisAndAllSubgraphs(all_subgraphs, put_weights_maybe_in_memory_func);
}

void Graph::SetName(const std::string& name) {
graph_proto_->set_name(name);
}
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/optimizer/attention_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ static NodeArg& MergeQkvWeights(Graph& graph, int64_t hidden_size,
utils::SetRawDataInTensorProto(initializer, result.data(), gsl::narrow<size_t>(element_count) * sizeof(MLFloat16));
}

return graph_utils::AddInitializer(graph, initializer);
return graph_utils::AddInitializerWithExternalData(graph, initializer);
}

static NodeArg* ConvertMaskToInt32(Graph& graph, NodeArg* mask_input, ProviderType provider_type,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ NodeArg* CreateInitializerFromVector(Graph& graph,
"total_count: ", total_count, " values.size(): ", values.size());

utils::SetRawDataInTensorProto(const_tensor, values.data(), values.size() * sizeof(int64_t));
return &graph_utils::AddInitializer(graph, const_tensor);
return &graph_utils::AddInitializerWithExternalData(graph, const_tensor);
}

NodeArg* InsertNodesForValidIndices(Graph& graph,
Expand Down
13 changes: 9 additions & 4 deletions onnxruntime/core/optimizer/constant_folding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ static bool ConstantFoldShapeNode(Graph& graph, Node& node) {
ONNX_NAMESPACE::TensorShapeProto result_shape;
result_shape.add_dim()->set_dim_value(clamped_slice_length);
constant_arg_out->SetShape(result_shape);
graph_utils::AddInitializer(graph, shape_constant);
graph_utils::AddInitializerWithExternalData(graph, shape_constant);
}

return is_concrete_shape; // convert to constant if this is true
Expand Down Expand Up @@ -317,19 +317,24 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level,
// Build the TensorProto that corresponds to the computed OrtValue and add it as initializer to the graph.
auto* constant_arg_out = node->MutableOutputDefs()[fetch_idx];
const Tensor& out_tensor = ort_value.Get<Tensor>();
constexpr const bool use_tensor_buffer_false = false;
constexpr const bool use_tensor_buffer_true = true;
ONNX_NAMESPACE::TensorProto out_tensorproto = utils::TensorToTensorProto(
out_tensor,
constant_arg_out->Name(),
use_tensor_buffer_false);
use_tensor_buffer_true);

ONNX_NAMESPACE::TensorShapeProto result_shape;
for (auto& dim : out_tensor.Shape().GetDims()) {
result_shape.add_dim()->set_dim_value(dim);
}

constant_arg_out->SetShape(result_shape);
graph.AddInitializedTensor(out_tensorproto);
// The data is too small and has been inlined.
if (!utils::HasExternalData(out_tensorproto)) {
ORT_THROW_IF_ERROR(graph.AddInitializedOrtValue(out_tensorproto, OrtValue()));
} else {
ORT_THROW_IF_ERROR(graph.AddInitializedOrtValue(out_tensorproto, ort_value));
}
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/optimizer/conv_add_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ Status ConvAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modifie
auto new_name = graph.GenerateNodeArgName("ConvAddFusion_B_" + B_input_name);
new_conv_B_tensor_proto.set_name(new_name);

NodeArg& new_conv_B_node_arg = graph_utils::AddInitializer(graph, new_conv_B_tensor_proto);
NodeArg& new_conv_B_node_arg = graph_utils::AddInitializerWithExternalData(graph, new_conv_B_tensor_proto);
graph_utils::ReplaceNodeInput(node, 2, new_conv_B_node_arg);

} else {
Expand All @@ -94,7 +94,7 @@ Status ConvAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modifie
auto new_name = graph.GenerateNodeArgName("ConvAddFusion_Add_B_" + add_B_tensor_proto->name());
new_conv_B_tensor_proto.set_name(new_name);

NodeArg& new_add_B_node_arg = graph_utils::AddInitializer(graph, new_conv_B_tensor_proto);
NodeArg& new_add_B_node_arg = graph_utils::AddInitializerWithExternalData(graph, new_conv_B_tensor_proto);
graph_utils::AddNodeInput(node, 2, new_add_B_node_arg);
}

Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/optimizer/conv_bn_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,10 @@ Status ConvBNFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_eff
new_conv_W_tensor_proto.set_name(new_W_name);
new_conv_B_tensor_proto.set_name(new_B_name);

NodeArg& new_conv_W_node_arg = graph_utils::AddInitializer(graph, new_conv_W_tensor_proto);
NodeArg& new_conv_W_node_arg = graph_utils::AddInitializerWithExternalData(graph, new_conv_W_tensor_proto);
graph_utils::ReplaceNodeInput(node, 1, new_conv_W_node_arg);

auto& new_conv_B_node_arg = graph_utils::AddInitializer(graph, new_conv_B_tensor_proto);
auto& new_conv_B_node_arg = graph_utils::AddInitializerWithExternalData(graph, new_conv_B_tensor_proto);

if (conv_inputs.size() == 3) {
graph_utils::ReplaceNodeInput(node, 2, new_conv_B_node_arg);
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/optimizer/conv_mul_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ Status ConvMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_ef
new_conv_W_tensor_proto.set_name(new_W_name);

// Replace initializers of conv node
NodeArg& new_conv_W_node_arg = graph_utils::AddInitializer(graph, new_conv_W_tensor_proto);
NodeArg& new_conv_W_node_arg = graph_utils::AddInitializerWithExternalData(graph, new_conv_W_tensor_proto);
graph_utils::ReplaceNodeInput(conv_node, 1, new_conv_W_node_arg);

if (is_3d) {
Expand All @@ -100,7 +100,7 @@ Status ConvMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_ef
auto new_B_name = graph.GenerateNodeArgName("ConvMulFusion_Mul_B_" + mul_B_tensor_proto->name());
new_conv_B_tensor_proto.set_name(new_B_name);

NodeArg& new_conv_B_node_arg = graph_utils::AddInitializer(graph, new_conv_B_tensor_proto);
NodeArg& new_conv_B_node_arg = graph_utils::AddInitializerWithExternalData(graph, new_conv_B_tensor_proto);
graph_utils::ReplaceNodeInput(conv_node, 2, new_conv_B_node_arg);
}

Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/optimizer/double_qdq_pairs_remover.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ static void ApplyNewInputValue(Graph& graph, Node& node, QDQ::InputIndex index,
auto new_name = graph.GenerateNodeArgName("DoubleQDQRemoved_" + node.InputDefs()[index]->Name());
new_input_tensor.set_name(new_name);
new_input_tensor.add_dims(1);
NodeArg& new_input = graph_utils::AddInitializer(graph, new_input_tensor);
NodeArg& new_input = graph_utils::AddInitializerWithExternalData(graph, new_input_tensor);
graph_utils::ReplaceNodeInput(node, index, new_input);
}

Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/optimizer/embed_layer_norm_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ static NodeArg* ExtractEmbedding(Graph& graph,
utils::SetRawDataInTensorProto(initializer, data, gsl::narrow<size_t>(element_count) * sizeof(MLFloat16));
}

NodeArg& node_arg = graph_utils::AddInitializer(graph, initializer);
NodeArg& node_arg = graph_utils::AddInitializerWithExternalData(graph, initializer);
modified = true;
return &node_arg;
}
Expand Down
8 changes: 6 additions & 2 deletions onnxruntime/core/optimizer/fuse_initializers_transformer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,12 @@
graph.RemoveEdge(node.Index(), next_node.Index(), 0, static_cast<int>(next_node_arg_index));

// Add the new converted Tensor in next node as initializer potentially with external data
ONNX_NAMESPACE::TensorProto dst_tensor = utils::TensorToTensorProto(new_data.Get<Tensor>(), new_arg_name, false);
auto& new_arg = graph_utils::AddInitializer(graph, dst_tensor);
ONNX_NAMESPACE::TensorProto dst_tensor = utils::TensorToTensorProto(new_data.Get<Tensor>(), new_arg_name, true);
if (!utils::HasExternalData(dst_tensor)) {
new_data = OrtValue(); // Data is inline
}

auto& new_arg = graph_utils::AddInitializerWithExternalData(graph, dst_tensor, std::move(new_data));

Check warning on line 145 in onnxruntime/core/optimizer/fuse_initializers_transformer.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <utility> for move [build/include_what_you_use] [4] Raw Output: onnxruntime/core/optimizer/fuse_initializers_transformer.cc:145: Add #include <utility> for move [build/include_what_you_use] [4]
graph_utils::ReplaceNodeInput(next_node, static_cast<int>(next_node_arg_index), new_arg);
}

Expand Down
Loading
Loading