Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -858,6 +858,11 @@ if (onnxruntime_USE_MIGRAPHX)
list(APPEND ORT_PROVIDER_FLAGS -DUSE_MIGRAPHX=1)
list(APPEND ONNXRUNTIME_PROVIDER_NAMES migraphx)
endif()

if (onnxruntime_USE_MIGRAPHX_INTERFACE AND (NOT onnxruntime_USE_MIGRAPHX))
list(APPEND ORT_PROVIDER_FLAGS -DUSE_MIGRAPHX_PROVIDER_INTERFACE=1)
endif()

if (onnxruntime_USE_ARMNN)
list(APPEND ORT_PROVIDER_FLAGS -DUSE_ARMNN=1)
list(APPEND ONNXRUNTIME_PROVIDER_NAMES armnn)
Expand Down
2 changes: 0 additions & 2 deletions onnxruntime/core/platform/posix/env.cc
Original file line number Diff line number Diff line change
Expand Up @@ -223,13 +223,11 @@ class PosixThread : public EnvThread {
} else {
errno = ret;
auto [err_no, err_msg] = GetErrnoInfo();
#if !defined(USE_MIGRAPHX)
LOGS_DEFAULT(ERROR) << "pthread_setaffinity_np failed for thread: " << syscall(SYS_gettid)
<< ", index: " << p->index
<< ", mask: " << *p->affinity
<< ", error code: " << err_no << " error msg: " << err_msg
<< ". Specify the number of threads explicitly so the affinity is not set.";
#endif
}
}
#endif
Expand Down
129 changes: 129 additions & 0 deletions onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,24 @@
return onnxruntime::MIGraphXExecutionProviderInfo::ToProviderOptions(options);
}

Status CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/,
const OrtKeyValuePairs* const* /*ep_metadata*/,
size_t num_devices,
ProviderOptions& provider_options,
const OrtSessionOptions& session_options,
const OrtLogger& logger,
std::unique_ptr<IExecutionProvider>& ep) override {

const ConfigOptions* config_options = &session_options.GetConfigOptions();

std::array<const void*, 2> configs_array = {&provider_options, config_options};
const void* arg = reinterpret_cast<const void*>(&configs_array);
auto ep_factory = CreateExecutionProviderFactory(&provider_options);
ep = ep_factory->CreateProvider(session_options, logger);

return Status::OK();
}

void Initialize() override {
InitializeRegistry();
}
Expand All @@ -156,9 +174,120 @@

} // namespace onnxruntime

#include "core/framework/error_code_helper.h"

// OrtEpApi infrastructure to be able to use the MigraphX/AMDGPU EP as an OrtEpFactory for auto EP selection.
struct MigraphXEpFactory : OrtEpFactory {
MigraphXEpFactory(const OrtApi& ort_api_in,
const char* ep_name,
OrtHardwareDeviceType hw_type,
const OrtLogger& default_logger_in)
: ort_api{ort_api_in}, ep_name{ep_name}, ort_hw_device_type{hw_type}, default_logger{default_logger_in} {
GetName = GetNameImpl;
GetVendor = GetVendorImpl;
GetSupportedDevices = GetSupportedDevicesImpl;
CreateEp = CreateEpImpl;
ReleaseEp = ReleaseEpImpl;
}

// Returns the name for the EP. Each unique factory configuration must have a unique name.
// Ex: a factory that supports NPU should have a different than a factory that supports GPU.
static const char* GetNameImpl(const OrtEpFactory* this_ptr) {
const auto* factory = static_cast<const MigraphXEpFactory*>(this_ptr);
return factory->ep_name.c_str();
}

static const char* GetVendorImpl(const OrtEpFactory* this_ptr) {
const auto* factory = static_cast<const MigraphXEpFactory*>(this_ptr);
return factory->vendor.c_str();
}

// Creates and returns OrtEpDevice instances for all OrtHardwareDevices that this factory supports.
// An EP created with this factory is expected to be able to execute a model with *all* supported
// hardware devices at once. A single instance of MigraphX EP is not currently setup to partition a model among
// multiple different MigraphX backends at once (e.g, npu, cpu, gpu), so this factory instance is set to only
// support one backend: gpu. To support a different backend, like npu, create a different factory instance
// that only supports NPU.
static OrtStatus* GetSupportedDevicesImpl(OrtEpFactory* this_ptr,
const OrtHardwareDevice* const* devices,
size_t num_devices,
OrtEpDevice** ep_devices,
size_t max_ep_devices,
size_t* p_num_ep_devices) {
size_t& num_ep_devices = *p_num_ep_devices;
auto* factory = static_cast<MigraphXEpFactory*>(this_ptr);

for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) {
const OrtHardwareDevice& device = *devices[i];
if (factory->ort_api.HardwareDevice_Type(&device) == factory->ort_hw_device_type){
//factory->ort_api.HardwareDevice_VendorId(&device) == factory->vendor_id) {
OrtKeyValuePairs* ep_options = nullptr;
factory->ort_api.CreateKeyValuePairs(&ep_options);
ORT_API_RETURN_IF_ERROR(
factory->ort_api.GetEpApi()->CreateEpDevice(factory, &device, nullptr, ep_options,
&ep_devices[num_ep_devices++]));
}
}

return nullptr;
}

static OrtStatus* CreateEpImpl(OrtEpFactory* /*this_ptr*/,
_In_reads_(num_devices) const OrtHardwareDevice* const* /*devices*/,
_In_reads_(num_devices) const OrtKeyValuePairs* const* /*ep_metadata*/,
_In_ size_t /*num_devices*/,
_In_ const OrtSessionOptions* /*session_options*/,
_In_ const OrtLogger* /*logger*/,
_Out_ OrtEp** /*ep*/) {
return onnxruntime::CreateStatus(ORT_INVALID_ARGUMENT, "[MigraphX/AMDGPU EP] EP factory does not support this method.");
}

static void ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* /*ep*/) {
// no-op as we never create an EP here.
}

const OrtApi& ort_api;
const OrtLogger& default_logger;
const std::string ep_name;
const std::string vendor{"AMD"};

Check warning on line 252 in onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc:252: Add #include <string> for string [build/include_what_you_use] [4]

const uint32_t vendor_id{0x1002};
const OrtHardwareDeviceType ort_hw_device_type; // Supported OrtHardwareDevice
};

extern "C" {
//
// Public symbols
//
OrtStatus* CreateEpFactories(const char* /*registration_name*/, const OrtApiBase* ort_api_base,
const OrtLogger* default_logger,
OrtEpFactory** factories, size_t max_factories, size_t* num_factories) {
const OrtApi* ort_api = ort_api_base->GetApi(ORT_API_VERSION);

// Factory could use registration_name or define its own EP name.
auto factory_gpu = std::make_unique<MigraphXEpFactory>(*ort_api,

Check warning on line 268 in onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <memory> for make_unique<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc:268: Add #include <memory> for make_unique<> [build/include_what_you_use] [4]
onnxruntime::kMIGraphXExecutionProvider,
OrtHardwareDeviceType_GPU,
*default_logger);

if (max_factories < 1) {
return ort_api->CreateStatus(ORT_INVALID_ARGUMENT,
"Not enough space to return EP factory. Need at least one.");
}

factories[0] = factory_gpu.release();
*num_factories = 1;

return nullptr;
}

OrtStatus* ReleaseEpFactory(OrtEpFactory* factory) {
delete static_cast<MigraphXEpFactory*>(factory);
return nullptr;
}

ORT_API(onnxruntime::Provider*, GetProvider) {
return &onnxruntime::g_provider;
}

}
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/migraphx/symbols.def
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
EXPORTS
GetProvider
CreateEpFactories
ReleaseEpFactory
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/provider_factory_creators.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
#include "core/providers/dnnl/dnnl_provider_factory_creator.h"
#endif

#if defined(USE_MIGRAPHX)
#if defined(USE_MIGRAPHX) || defined(USE_MIGRAPHX_PROVIDER_INTERFACE)
#include "core/providers/migraphx/migraphx_provider_factory_creator.h"
#endif

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -361,15 +361,13 @@ std::unique_ptr<IDataTransfer> CreateGPUDataTransfer() {
}
#endif

#ifdef USE_MIGRAPHX
std::unique_ptr<IAllocator> CreateMIGraphXAllocator(int16_t device_id, const char* name) {
return g_host->CreateMIGraphXAllocator(device_id, name);
}

std::unique_ptr<IAllocator> CreateMIGraphXPinnedAllocator(int16_t device_id, const char* name) {
return g_host->CreateMIGraphXPinnedAllocator(device_id, name);
}
#endif

std::string GetEnvironmentVar(const std::string& var_name) {
return g_host->GetEnvironmentVar(var_name);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,8 @@ struct ProviderHost {
virtual Status CudaCall_false(int retCode, const char* exprString, const char* libName, int successCode, const char* msg, const char* file, const int line) = 0;
virtual void CudaCall_true(int retCode, const char* exprString, const char* libName, int successCode, const char* msg, const char* file, const int line) = 0;

#ifdef USE_MIGRAPHX
virtual std::unique_ptr<IAllocator> CreateMIGraphXAllocator(int16_t device_id, const char* name) = 0;
virtual std::unique_ptr<IAllocator> CreateMIGraphXPinnedAllocator(int16_t device_id, const char* name) = 0;
#endif

#ifdef USE_ROCM
virtual std::unique_ptr<IAllocator> CreateROCMAllocator(int16_t device_id, const char* name) = 0;
Expand Down
2 changes: 0 additions & 2 deletions onnxruntime/core/session/provider_bridge_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -292,10 +292,8 @@ struct ProviderHostImpl : ProviderHost {
Status CudaCall_false(int retCode, const char* exprString, const char* libName, int successCode, const char* msg, const char* file, const int line) override { return GetProviderInfo_CUDA().CudaCall_false(retCode, exprString, libName, successCode, msg, file, line); }
void CudaCall_true(int retCode, const char* exprString, const char* libName, int successCode, const char* msg, const char* file, const int line) override { GetProviderInfo_CUDA().CudaCall_true(retCode, exprString, libName, successCode, msg, file, line); }

#ifdef USE_MIGRAPHX
std::unique_ptr<IAllocator> CreateMIGraphXAllocator(int16_t device_id, const char* name) override { return GetProviderInfo_MIGraphX().CreateMIGraphXAllocator(device_id, name); }
std::unique_ptr<IAllocator> CreateMIGraphXPinnedAllocator(int16_t device_id, const char* name) override { return GetProviderInfo_MIGraphX().CreateMIGraphXPinnedAllocator(device_id, name); }
#endif

std::unique_ptr<IDataTransfer> CreateGPUDataTransfer() override { return GetProviderInfo_CUDA().CreateGPUDataTransfer(); }

Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/python/onnxruntime_pybind_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -952,7 +952,7 @@ static std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory
<< "TensorRT-ExecutionProvider.html#requirements to ensure all dependencies are met.";
#endif
} else if (type == kMIGraphXExecutionProvider) {
#ifdef USE_MIGRAPHX
#if defined(USE_MIGRAPHX) || defined(USE_MIGRAPHX_PROVIDER_INTERFACE)
std::string calibration_table;
std::string save_model_path;
std::string load_model_path;
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/python/onnxruntime_pybind_state_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ struct OrtStatus {
#define BACKEND_DNNL ""
#endif

#if USE_MIGRAPHX
#if defined(USE_MIGRAPHX) || defined(USE_MIGRAPHX_PROVIDER_INTERFACE)
#define BACKEND_MIGRAPHX "-MIGRAPHX"
#else
#define BACKEND_MIGRAPHX ""
Expand Down Expand Up @@ -132,7 +132,7 @@ struct OrtStatus {
#if defined(USE_NV) || defined(USE_NV_PROVIDER_INTERFACE)
#include "core/providers/nv_tensorrt_rtx/nv_provider_factory.h"
#endif
#ifdef USE_MIGRAPHX
#if defined(USE_MIGRAPHX) || defined(USE_MIGRAPHX_PROVIDER_INTERFACE)
#include "core/providers/migraphx/migraphx_provider_factory.h"
#include "core/providers/migraphx/migraphx_execution_provider_info.h"
#endif
Expand Down
25 changes: 25 additions & 0 deletions onnxruntime/test/providers/migraphx/migraphx_basic_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -188,5 +188,30 @@ TEST(MIGraphXExecutionProviderTest, canEvalArgument) {
ASSERT_EQ(canEvalNodeArgument(gv, node2, {1}, input_nodes), true);
}

TEST(MIGraphXExecutionProviderTest, AutoEp_PreferGpu) {
PathString model_name = ORT_TSTR("migraphx_basic_test.onnx");

onnxruntime::Model model("test", false, DefaultLoggingManager().DefaultLogger());
std::vector<int> dims = {1, 3, 2};
CreateBaseModel(model, dims);

auto status = onnxruntime::Model::Save(model, model_name);
ASSERT_TRUE(status.IsOK());

auto env = Ort::Env();
env.UpdateEnvWithCustomLogLevel(OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING);

{
env.RegisterExecutionProviderLibrary(kMIGraphXExecutionProvider, ORT_TSTR("onnxruntime_providers_migraphx.dll"));

Ort::SessionOptions so;
so.SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy_PREFER_GPU);
Ort::Session session_object(env, model_name.c_str(), so);
EXPECT_TRUE(SessionHasEp(session_object, kMIGraphXExecutionProvider));
}

env.UnregisterExecutionProviderLibrary(kMIGraphXExecutionProvider);
}

} // namespace test
} // namespace onnxruntime
1 change: 1 addition & 0 deletions tools/ci_build/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,7 @@ def generate_build_tree(
"-Donnxruntime_USE_OPENVINO_INTERFACE=" + ("ON" if args.enable_generic_interface else "OFF"),
"-Donnxruntime_USE_VITISAI_INTERFACE=" + ("ON" if args.enable_generic_interface else "OFF"),
"-Donnxruntime_USE_QNN_INTERFACE=" + ("ON" if args.enable_generic_interface else "OFF"),
"-Donnxruntime_USE_MIGRAPHX_INTERFACE=" + ("ON" if args.enable_generic_interface else "OFF"),
# set vars for migraphx
"-Donnxruntime_USE_MIGRAPHX=" + ("ON" if args.use_migraphx else "OFF"),
"-Donnxruntime_DISABLE_CONTRIB_OPS=" + ("ON" if args.disable_contrib_ops else "OFF"),
Expand Down
Loading