Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 54 additions & 10 deletions onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@

#include "core/session/plugin_ep/ep_plugin_provider_interfaces.h"

#include <gsl/gsl>
#include <algorithm>
#include <memory>
#include <string>
#include <sstream>
#include <unordered_set>
#include <utility>
#include <vector>
Expand Down Expand Up @@ -117,6 +120,16 @@ static OrtDevice GetOrtDeviceForPluginEp(gsl::span<const OrtEpDevice* const> ep_
return device_memory_info != nullptr ? device_memory_info->device : OrtDevice();
}

// Returns the first node in `ep_nodes` that is already assigned to an execution provider other than `ep_type`,
// or nullptr if no such node exists.
//
// Nodes with an empty execution provider type are skipped: they have not been assigned to any EP yet, so
// reporting them as "assigned to a different EP" would produce a misleading diagnostic with an empty EP name.
//
// ep_type:  Type name of the EP that is claiming the nodes.
// ep_nodes: Nodes the EP selected. Entries are expected to be non-null (they are dereferenced unconditionally).
static const Node* FindFirstNodeAssignedToOtherEP(const std::string& ep_type,
                                                  gsl::span<const EpNode* const> ep_nodes) {
  auto node_iter = std::find_if(ep_nodes.begin(), ep_nodes.end(),
                                [&ep_type](const EpNode* node) -> bool {
                                  const auto& node_ep_type = node->GetInternalNode().GetExecutionProviderType();
                                  // Only a non-empty, different EP type counts as "assigned to another EP".
                                  return !node_ep_type.empty() && node_ep_type != ep_type;
                                });

  return node_iter != ep_nodes.end() ? &(*node_iter)->GetInternalNode() : nullptr;
}

PluginExecutionProvider::PluginExecutionProvider(UniqueOrtEp ep, const OrtSessionOptions& session_options,
OrtEpFactory& ep_factory,
gsl::span<const OrtEpDevice* const> ep_devices,
Expand Down Expand Up @@ -158,17 +171,33 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie
ORT_UNUSED_PARAMETER(resource_accountant); // TODO: Add support? Not used by prioritized EPs
ORT_UNUSED_PARAMETER(kernel_lookup); // TODO: Add support? Not used by prioritized EPs, so probably not needed?

const logging::Logger& logger = GetLogger() != nullptr ? *GetLogger() : logging::LoggingManager::DefaultLogger();
auto log_unsupported_node_info = [&ep_type = Type(), &logger](gsl::span<const EpNode* const> ep_nodes) {
std::ostringstream oss;
oss << "OrtEp::GetCapability() specified nodes that cannot be assigned to " << ep_type << ". ";

if (const Node* node_for_other_ep = FindFirstNodeAssignedToOtherEP(ep_type, ep_nodes);
node_for_other_ep != nullptr) {
oss << "Found one or more nodes that were already assigned to a different EP named '"
<< node_for_other_ep->GetExecutionProviderType() << "'. Ex: "
<< node_for_other_ep->OpType() << " node with name '"
<< node_for_other_ep->Name() << "'.";
}

LOGS(logger, WARNING) << oss.str();
};

std::unique_ptr<EpGraph> ep_graph = nullptr;
if (Status status = EpGraph::Create(graph_viewer, ep_graph); !status.IsOK()) {
LOGS_DEFAULT(ERROR) << "Failed to create OrtGraph: " << status.ToString();
LOGS(logger, ERROR) << "Failed to create OrtGraph for " << Type() << ": " << status.ToString();
return {};
}

OrtEpGraphSupportInfo api_graph_support_info(*ep_graph);
Status status = ToStatusAndRelease(ort_ep_->GetCapability(ort_ep_.get(), ep_graph->ToExternal(), &api_graph_support_info));

if (!status.IsOK()) {
LOGS_DEFAULT(ERROR) << "OrtEp::GetCapability() failed with error: " << status.ToString();
LOGS(logger, ERROR) << "OrtEp::GetCapability() for " << Type() << " failed with error: " << status.ToString();
return {};
}

Expand All @@ -188,6 +217,14 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie
indexed_sub_graph->nodes.push_back(node_grouping.nodes[0]->GetInternalNode().Index());
result.push_back(std::make_unique<ComputeCapability>(std::move(indexed_sub_graph)));
} else if (node_grouping.kind == OrtEpGraphSupportInfo::NodeGroupingKind::kFusedNode) {
if (node_grouping.nodes.empty()) {
// The EpGraphSupportInfo_AddNodesToFuse() C API should already return an error if the EP tries to provide
// an empty array of nodes from OrtEp::GetCapability(). However, we check here too just in case this changes.
LOGS(logger, WARNING) << "OrtEp::GetCapability() for " << Type() << " set an empty array of nodes "
<< "when specifying supported nodes.";
return {};
}

std::unordered_set<const Node*> node_set;
node_set.reserve(node_grouping.nodes.size());

Expand All @@ -207,27 +244,34 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie
this->Type(), this->Type(), /*node_unit_map*/ nullptr,
node_grouping.fusion_options.drop_constant_initializers);

// Check if utils::CreateSupportedPartitions returned zero results.
// Happens if nodes have already been assigned to another EP.
if (capabilities.empty()) {
log_unsupported_node_info(node_grouping.nodes);
return {};
}

if (capabilities.size() > 1) {
LOGS_DEFAULT(ERROR) << "OrtEp::GetCapability() set nodes that cannot be fused together. "
<< "Please ensure that the nodes provided to EpGraphSupportInfo_AddFusedNodes() do not "
LOGS(logger, ERROR) << "OrtEp::GetCapability() for " << Type() << " set nodes that cannot be fused together. "
<< "Please ensure that the nodes provided to EpGraphSupportInfo_AddNodesToFuse() do not "
<< "have an unsupported node in any path between two of the supported nodes.";
return {};
}

// Enforce that the nodes in node_set match the nodes in capabilities[0]
// Log if the nodes in node_set do not match the nodes in capabilities[0], which occurs when EP selects nodes
// assigned to a different EP.
// TODO(adrianlizarraga): This check can be removed when we stop using utils::CreateSupportedPartitions() above.
std::vector<NodeIndex>& capability_node_indices = capabilities[0]->sub_graph->nodes;
std::unordered_set<NodeIndex> capability_node_indices_set(capability_node_indices.begin(),
capability_node_indices.end());

ORT_ENFORCE(node_set.size() == capability_node_indices_set.size());
ORT_ENFORCE(std::all_of(node_set.begin(), node_set.end(), [&capability_node_indices_set](const Node* node) {
return capability_node_indices_set.count(node->Index()) != 0;
}));
if (node_set.size() != capability_node_indices_set.size()) {
log_unsupported_node_info(node_grouping.nodes);
}

result.push_back(std::move(capabilities[0]));
} else {
LOGS_DEFAULT(ERROR) << "PluginExecutionProvider::GetCapability() has invalid NodeGroupingKind: "
LOGS(logger, ERROR) << "PluginExecutionProvider::GetCapability() has invalid NodeGroupingKind: "
<< static_cast<int>(node_grouping.kind);
return {};
}
Expand Down
134 changes: 134 additions & 0 deletions onnxruntime/test/framework/ep_plugin_provider_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,14 @@

#include "core/session/plugin_ep/ep_plugin_provider_interfaces.h"

#include <filesystem>
#include "gsl/gsl"
#include "gtest/gtest.h"

#include "core/common/logging/sinks/file_sink.h"
#include "core/graph/graph_viewer.h"
#include "core/graph/model.h"
#include "core/optimizer/graph_optimizer_registry.h"
#include "core/session/abi_devices.h"
#include "core/session/onnxruntime_cxx_api.h"
#include "test/util/include/asserts.h"
Expand All @@ -23,6 +28,14 @@ struct ApiPtrs {
const gsl::not_null<const ::OrtEpApi*> ep_api;
};

// Reads the whole file at `filename` and records a non-fatal test failure if `look_for` does not occur in it.
//
// The failure message includes the searched text, whether the file could be opened, and how many bytes were read,
// so that "file missing/unreadable" can be told apart from "file present but text absent".
static void CheckStringInFile(const PathString& filename, const std::string& look_for) {
  std::ifstream ifs{filename};
  // Captured before reading: an unopened stream simply yields empty content below.
  const bool opened = ifs.is_open();
  std::string content(std::istreambuf_iterator<char>{ifs},
                      std::istreambuf_iterator<char>{});

  EXPECT_NE(content.find(look_for), std::string::npos)
      << "Expected text not found: \"" << look_for << "\" (file opened: " << (opened ? "yes" : "no")
      << ", bytes read: " << content.size() << ")";
}

// Normally, a plugin EP would be implemented in a separate library.
// The `test_plugin_ep` namespace contains a local implementation intended for unit testing.
namespace test_plugin_ep {
Expand Down Expand Up @@ -114,6 +127,10 @@ MakeTestOrtEpResult MakeTestOrtEp(std::vector<const OrtEpDevice*> ep_devices = {
return result;
}

// Kernel lookup stub for tests: answers every query with "no kernel" (nullptr), regardless of the node.
class MockKernelLookup : public IExecutionProvider::IKernelLookup {
  const KernelCreateInfo* LookUpKernel(const Node& /*node*/) const override {
    return nullptr;
  }
};

} // namespace test_plugin_ep

TEST(PluginExecutionProviderTest, GetPreferredLayout) {
Expand Down Expand Up @@ -317,4 +334,121 @@ TEST(PluginExecutionProviderTest, InferOrtDeviceFromDeviceMemoryInfo) {
#endif // !defined(ORT_NO_EXCEPTIONS)
}

// OrtEp::GetCapability implementation for tests: queries every node in `graph` and asks for all of them
// to be fused into a single group.
//
// Returns the first non-null OrtStatus produced by any of the API calls, or nullptr if none of them
// returned a status.
static OrtStatus* ORT_API_CALL GetCapabilityTakeAllNodes(OrtEp* this_ptr, const OrtGraph* graph,
                                                         OrtEpGraphSupportInfo* graph_support_info) noexcept {
  auto* test_ep = static_cast<test_plugin_ep::TestOrtEp*>(this_ptr);

  // Ask for the node count first so the node array can be sized up front.
  size_t node_count = 0;
  OrtStatus* error = test_ep->ort_api->Graph_GetNumNodes(graph, &node_count);
  if (error != nullptr) {
    return error;
  }

  std::vector<const OrtNode*> all_nodes(node_count);
  error = test_ep->ort_api->Graph_GetNodes(graph, all_nodes.data(), all_nodes.size());
  if (error != nullptr) {
    return error;
  }

  // The status of the final call is the overall result (nullptr when no error status was returned).
  return test_ep->ep_api->EpGraphSupportInfo_AddNodesToFuse(graph_support_info,
                                                            all_nodes.data(), all_nodes.size(), nullptr);
}

// Tests that GetCapability() doesn't crash if a plugin EP tries to claim only nodes that are
// already assigned to another EP.
TEST(PluginExecutionProviderTest, GetCapability_ClaimOnlyNodesAssignedToOtherEP) {
  auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp();
  ort_ep->GetCapability = GetCapabilityTakeAllNodes;

  // Load a model and forcibly assign the only Mul node to another EP named 'OtherEp'.
  std::shared_ptr<Model> model;
  Status status = Model::Load(ORT_TSTR("testdata/mul_1.onnx"), model, nullptr, DefaultLoggingManager().DefaultLogger());
  ASSERT_TRUE(status.IsOK()) << status.ToString();

  Graph& graph = model->MainGraph();
  size_t num_assigned_nodes = 0;
  for (Node& node : graph.Nodes()) {
    node.SetExecutionProviderType("OtherEp");
    ++num_assigned_nodes;
  }
  // Guard against a vacuous test: at least one node must actually be assigned to 'OtherEp'.
  ASSERT_GT(num_assigned_nodes, size_t{0});

  // Use a log file name unique to this test so tests that run concurrently cannot clobber each other's logs.
  // std::filesystem::remove() is a no-op (returns false) when the file does not exist.
  std::filesystem::path log_file = ORT_TSTR("log_get_capability_claim_only_assigned.txt");
  std::filesystem::remove(log_file);

  // Call IExecutionProvider::GetCapability. The underlying OrtEp will try to take all nodes.
  // Should not crash and should return an empty result.
  {
    // Scoped so that the log manager and logger are destroyed before the file is read back below.
    logging::LoggingManager log_manager{std::make_unique<logging::FileSink>(log_file, false, false),
                                        logging::Severity::kWARNING, false,
                                        logging::LoggingManager::InstanceType::Temporal};
    auto file_logger = log_manager.CreateLogger("FileLogger");
    ep->SetLogger(file_logger.get());  // Make EP log to a file.

    auto compute_capabilities = ep->GetCapability(GraphViewer(graph),
                                                  test_plugin_ep::MockKernelLookup{},
                                                  GraphOptimizerRegistry(nullptr, nullptr, file_logger.get()),
                                                  nullptr);
    EXPECT_TRUE(compute_capabilities.empty());  // No compute capabilities returned.
  }

  ASSERT_TRUE(std::filesystem::exists(log_file));
  CheckStringInFile(log_file, "Found one or more nodes that were already assigned to a different EP named 'OtherEp'");

  std::filesystem::remove(log_file);
}

// Tests that GetCapability() doesn't crash if a plugin EP tries to claim a mix of unassigned nodes and
// nodes that are already assigned to another EP.
TEST(PluginExecutionProviderTest, GetCapability_ClaimSomeNodesAssignedToOtherEP) {
  auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp();
  ort_ep->GetCapability = GetCapabilityTakeAllNodes;

  // Load a model and forcibly assign only the first Add node to another EP named 'OtherEp'.
  // Other nodes are unassigned and should be taken by the test plugin EP.
  std::shared_ptr<Model> model;
  Status status = Model::Load(ORT_TSTR("testdata/add_mul_add.onnx"), model, nullptr, DefaultLoggingManager().DefaultLogger());
  ASSERT_TRUE(status.IsOK()) << status.ToString();

  Graph& graph = model->MainGraph();
  bool assigned_add_0 = false;
  for (Node& node : graph.Nodes()) {
    if (node.Name() == "add_0") {
      node.SetExecutionProviderType("OtherEp");
      assigned_add_0 = true;
    }
  }
  // Without this node being found, the rest of the test would not exercise the mixed-assignment path.
  ASSERT_TRUE(assigned_add_0) << "Expected the model to contain a node named 'add_0'";

  // Use a log file name unique to this test so tests that run concurrently cannot clobber each other's logs.
  // std::filesystem::remove() is a no-op (returns false) when the file does not exist.
  std::filesystem::path log_file = ORT_TSTR("log_get_capability_claim_some_assigned.txt");
  std::filesystem::remove(log_file);

  // Call IExecutionProvider::GetCapability. The underlying OrtEp will try to take all nodes.
  // Should not crash and should return a single compute capability with 2 out of the 3 nodes.
  {
    // Scoped so that the log manager and logger are destroyed before the file is read back below.
    logging::LoggingManager log_manager{std::make_unique<logging::FileSink>(log_file, false, false),
                                        logging::Severity::kWARNING, false,
                                        logging::LoggingManager::InstanceType::Temporal};
    auto file_logger = log_manager.CreateLogger("FileLogger");
    ep->SetLogger(file_logger.get());  // Make EP log to a file.

    auto compute_capabilities = ep->GetCapability(GraphViewer(graph),
                                                  test_plugin_ep::MockKernelLookup{},
                                                  GraphOptimizerRegistry(nullptr, nullptr, file_logger.get()),
                                                  nullptr);

    // Must be a fatal assertion: the next check indexes element 0, which is out of bounds for an empty result.
    ASSERT_EQ(compute_capabilities.size(), size_t{1});
    EXPECT_EQ(compute_capabilities[0]->sub_graph->nodes.size(), size_t{2});
  }

  ASSERT_TRUE(std::filesystem::exists(log_file));
  CheckStringInFile(log_file, "Found one or more nodes that were already assigned to a different EP named 'OtherEp'");

  std::filesystem::remove(log_file);
}
} // namespace onnxruntime::test
Loading