Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/migraphx/migraphx_allocator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ void MIGraphXAllocator::CheckDevice() const {
int current_device;
auto hip_err = hipGetDevice(&current_device);
if (hip_err == hipSuccess) {
ORT_ENFORCE(current_device == Info().id);
ORT_ENFORCE(current_device == Info().device.Id());
}
#endif
}
Expand Down
14 changes: 7 additions & 7 deletions onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,10 @@ struct MIGraphX_Provider : Provider {
const OrtSessionOptions& session_options,
const OrtLogger& logger,
std::unique_ptr<IExecutionProvider>& ep) override {
ORT_UNUSED_PARAMETER(num_devices);
const ConfigOptions* config_options = &session_options.GetConfigOptions();

std::array<const void*, 2> configs_array = {&provider_options, config_options};
const void* arg = reinterpret_cast<const void*>(&configs_array);
auto ep_factory = CreateExecutionProviderFactory(&provider_options);
ep = ep_factory->CreateProvider(session_options, logger);

Expand All @@ -181,7 +181,7 @@ struct MigraphXEpFactory : OrtEpFactory {
const char* ep_name,
OrtHardwareDeviceType hw_type,
const OrtLogger& default_logger_in)
: ort_api{ort_api_in}, ep_name{ep_name}, ort_hw_device_type{hw_type}, default_logger{default_logger_in} {
: ort_api{ort_api_in}, default_logger{default_logger_in}, ep_name{ep_name}, ort_hw_device_type{hw_type} {
GetName = GetNameImpl;
GetVendor = GetVendorImpl;
GetSupportedDevices = GetSupportedDevicesImpl;
Expand All @@ -191,12 +191,12 @@ struct MigraphXEpFactory : OrtEpFactory {

// Returns the name for the EP. Each unique factory configuration must have a unique name.
// Ex: a factory that supports NPU should have a different than a factory that supports GPU.
static const char* GetNameImpl(const OrtEpFactory* this_ptr) {
static const char* GetNameImpl(const OrtEpFactory* this_ptr) noexcept {
const auto* factory = static_cast<const MigraphXEpFactory*>(this_ptr);
return factory->ep_name.c_str();
}

static const char* GetVendorImpl(const OrtEpFactory* this_ptr) {
static const char* GetVendorImpl(const OrtEpFactory* this_ptr) noexcept {
const auto* factory = static_cast<const MigraphXEpFactory*>(this_ptr);
return factory->vendor.c_str();
}
Expand All @@ -212,7 +212,7 @@ struct MigraphXEpFactory : OrtEpFactory {
size_t num_devices,
OrtEpDevice** ep_devices,
size_t max_ep_devices,
size_t* p_num_ep_devices) {
size_t* p_num_ep_devices) noexcept {
size_t& num_ep_devices = *p_num_ep_devices;
auto* factory = static_cast<MigraphXEpFactory*>(this_ptr);

Expand All @@ -237,11 +237,11 @@ struct MigraphXEpFactory : OrtEpFactory {
_In_ size_t /*num_devices*/,
_In_ const OrtSessionOptions* /*session_options*/,
_In_ const OrtLogger* /*logger*/,
_Out_ OrtEp** /*ep*/) {
_Out_ OrtEp** /*ep*/) noexcept {
return onnxruntime::CreateStatus(ORT_INVALID_ARGUMENT, "[MigraphX/AMDGPU EP] EP factory does not support this method.");
}

static void ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* /*ep*/) {
static void ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* /*ep*/) noexcept {
// no-op as we never create an EP here.
}

Expand Down
15 changes: 15 additions & 0 deletions onnxruntime/test/providers/migraphx/migraphx_basic_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,21 @@ TEST(MIGraphXExecutionProviderTest, canEvalArgument) {
ASSERT_EQ(canEvalNodeArgument(gv, node2, {1}, input_nodes), true);
}

static bool SessionHasEp(Ort::Session& session, const char* ep_name) {
// Access the underlying InferenceSession.
const OrtSession* ort_session = session;
const InferenceSession* s = reinterpret_cast<const InferenceSession*>(ort_session);
bool has_ep = false;

for (const auto& provider : s->GetRegisteredProviderTypes()) {
if (provider == ep_name) {
has_ep = true;
break;
}
}
return has_ep;
}

TEST(MIGraphXExecutionProviderTest, AutoEp_PreferGpu) {
PathString model_name = ORT_TSTR("migraphx_basic_test.onnx");

Expand Down