Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Expose GetOrtvalueInitializer via provider bridge
  • Loading branch information
yuslepukhin committed Aug 15, 2025
commit 83b67b54c05a8f8e86eaf424de9065d67f246c3d
Original file line number Diff line number Diff line change
Expand Up @@ -1011,6 +1011,8 @@ struct ProviderHost {
virtual void Graph__AddInitializedTensor(Graph* p, const ONNX_NAMESPACE::TensorProto& tensor) = 0;
// We pass OrtValue by reference here (as opposed to the original Graph function) to avoid header inclusion
virtual Status Graph__AddInitializedOrtValue(Graph* p, const ONNX_NAMESPACE::TensorProto& tensor, const OrtValue& value) = 0;
virtual bool Graph__GetOrtValueInitializer(const Graph* p, const std::string& tensor_name, OrtValue& value,
bool check_outer_scope) = 0;
virtual Node& Graph__AddNode(Graph* p, const std::string& name, const std::string& op_type, const std::string& description, const gsl::span<NodeArg* const>& input_args, const gsl::span<NodeArg* const>& output_args, const NodeAttributes* attributes, const std::string& domain) = 0;
virtual Node& Graph__AddNode(Graph* p, const std::string& name, const std::string& op_type, const std::string& description, const gsl::span<NodeArg* const>& input_args, const gsl::span<NodeArg* const>& output_args, NodeAttributes&& attributes, const std::string& domain) = 0;
virtual Node& Graph__AddNode(Graph* p, const Node& other) = 0;
Expand Down Expand Up @@ -1074,6 +1076,8 @@ struct ProviderHost {
virtual const ONNX_NAMESPACE::TensorProto* GraphViewer__GetConstantInitializer(const GraphViewer* p,
const std::string& name,
bool check_outer_scope) const = 0;
virtual bool GraphViewer__GetOrtValueInitializer(const GraphViewer* p, const std::string& tensor_name,
OrtValue& value) = 0;
virtual const Node* GraphViewer__ParentNode(const GraphViewer* p) = 0;
virtual int GraphViewer__NumberOfNodes(const GraphViewer* p) noexcept = 0;
virtual int GraphViewer__MaxNodeIndex(const GraphViewer* p) noexcept = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1041,6 +1041,10 @@ struct Graph final {
Status AddInitializedOrtValue(const ONNX_NAMESPACE::TensorProto& tensor, const OrtValue& ort_value) {
return g_host->Graph__AddInitializedOrtValue(this, tensor, ort_value);
}
bool GetOrtValueInitializer(const std::string& tensor_name, OrtValue& ort_value,
bool check_outer_scope = false) const {
return g_host->Graph__GetOrtValueInitializer(this, tensor_name, ort_value, check_outer_scope);
}
Node& AddNode(const std::string& name, const std::string& op_type, const std::string& description, gsl::span<NodeArg* const> input_args, gsl::span<NodeArg* const> output_args, const NodeAttributes* attributes, const std::string& domain) { return g_host->Graph__AddNode(this, name, op_type, description, input_args, output_args, attributes, domain); }
Node& AddNode(const std::string& name, const std::string& op_type, const std::string& description, gsl::span<NodeArg* const> input_args, gsl::span<NodeArg* const> output_args, NodeAttributes&& attributes, const std::string& domain) { return g_host->Graph__AddNode(this, name, op_type, description, input_args, output_args, std::move(attributes), domain); }
Node& AddNode(const Node& other) { return g_host->Graph__AddNode(this, other); }
Expand Down Expand Up @@ -1124,6 +1128,9 @@ class GraphViewer final {
bool check_outer_scope = true) const {
return g_host->GraphViewer__GetConstantInitializer(this, name, check_outer_scope);
}
bool GetOrtValueInitializer(const std::string& tensor_name, OrtValue& ort_value) const {
return g_host->GraphViewer__GetOrtValueInitializer(this, tensor_name, ort_value);
}
const Node* ParentNode() const { return g_host->GraphViewer__ParentNode(this); }

int NumberOfNodes() const noexcept { return g_host->GraphViewer__NumberOfNodes(this); }
Expand Down
8 changes: 8 additions & 0 deletions onnxruntime/core/session/provider_bridge_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1258,6 +1258,10 @@ struct ProviderHostImpl : ProviderHost {
void Graph__AddInitializedTensor(Graph* p, const ONNX_NAMESPACE::TensorProto& tensor) override { p->AddInitializedTensor(tensor); }
Status Graph__AddInitializedOrtValue(Graph* p, const ONNX_NAMESPACE::TensorProto& tensor,
const OrtValue& value) override { return p->AddInitializedOrtValue(tensor, value); }
bool Graph__GetOrtValueInitializer(const Graph* p, const std::string& tensor_name, OrtValue& value,
bool check_outer_scope) override {
return p->GetOrtValueInitializer(tensor_name, value, check_outer_scope);
}
Node& Graph__AddNode(Graph* p, const std::string& name, const std::string& op_type, const std::string& description, const gsl::span<NodeArg* const>& input_args, const gsl::span<NodeArg* const>& output_args, const NodeAttributes* attributes, const std::string& domain) override {
return p->AddNode(name, op_type, description, input_args, output_args, attributes, domain);
}
Expand Down Expand Up @@ -1356,6 +1360,10 @@ struct ProviderHostImpl : ProviderHost {
bool check_outer_scope) const override {
return p->GetConstantInitializer(name, check_outer_scope);
}
bool GraphViewer__GetOrtValueInitializer(const GraphViewer* p, const std::string& tensor_name,
OrtValue& value) override {
return p->GetOrtValueInitializer(tensor_name, value);
}
const Node* GraphViewer__ParentNode(const GraphViewer* p) override { return p->ParentNode(); }
int GraphViewer__NumberOfNodes(const GraphViewer* p) noexcept override { return p->NumberOfNodes(); }
int GraphViewer__MaxNodeIndex(const GraphViewer* p) noexcept override { return p->MaxNodeIndex(); }
Expand Down
Loading