Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/migraphx/migraphx_allocator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ void MIGraphXAllocator::CheckDevice() const {
int current_device;
auto hip_err = hipGetDevice(&current_device);
if (hip_err == hipSuccess) {
ORT_ENFORCE(current_device == Info().id);
ORT_ENFORCE(current_device == Info().device.Id());
}
#endif
}
Expand Down
14 changes: 7 additions & 7 deletions onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,10 @@ struct MIGraphX_Provider : Provider {
const OrtSessionOptions& session_options,
const OrtLogger& logger,
std::unique_ptr<IExecutionProvider>& ep) override {
ORT_UNUSED_PARAMETER(num_devices);
const ConfigOptions* config_options = &session_options.GetConfigOptions();

std::array<const void*, 2> configs_array = {&provider_options, config_options};
const void* arg = reinterpret_cast<const void*>(&configs_array);
auto ep_factory = CreateExecutionProviderFactory(&provider_options);
ep = ep_factory->CreateProvider(session_options, logger);

Expand All @@ -181,7 +181,7 @@ struct MigraphXEpFactory : OrtEpFactory {
const char* ep_name,
OrtHardwareDeviceType hw_type,
const OrtLogger& default_logger_in)
: ort_api{ort_api_in}, ep_name{ep_name}, ort_hw_device_type{hw_type}, default_logger{default_logger_in} {
: ort_api{ort_api_in}, default_logger{default_logger_in}, ep_name{ep_name}, ort_hw_device_type{hw_type} {
GetName = GetNameImpl;
GetVendor = GetVendorImpl;
GetSupportedDevices = GetSupportedDevicesImpl;
Expand All @@ -191,12 +191,12 @@ struct MigraphXEpFactory : OrtEpFactory {

// Returns the name for the EP. Each unique factory configuration must have a unique name.
// Ex: a factory that supports NPU should have a different than a factory that supports GPU.
static const char* GetNameImpl(const OrtEpFactory* this_ptr) {
static const char* GetNameImpl(const OrtEpFactory* this_ptr) noexcept {
const auto* factory = static_cast<const MigraphXEpFactory*>(this_ptr);
return factory->ep_name.c_str();
}

static const char* GetVendorImpl(const OrtEpFactory* this_ptr) {
static const char* GetVendorImpl(const OrtEpFactory* this_ptr) noexcept {
const auto* factory = static_cast<const MigraphXEpFactory*>(this_ptr);
return factory->vendor.c_str();
}
Expand All @@ -212,7 +212,7 @@ struct MigraphXEpFactory : OrtEpFactory {
size_t num_devices,
OrtEpDevice** ep_devices,
size_t max_ep_devices,
size_t* p_num_ep_devices) {
size_t* p_num_ep_devices) noexcept {
size_t& num_ep_devices = *p_num_ep_devices;
auto* factory = static_cast<MigraphXEpFactory*>(this_ptr);

Expand All @@ -237,11 +237,11 @@ struct MigraphXEpFactory : OrtEpFactory {
_In_ size_t /*num_devices*/,
_In_ const OrtSessionOptions* /*session_options*/,
_In_ const OrtLogger* /*logger*/,
_Out_ OrtEp** /*ep*/) {
_Out_ OrtEp** /*ep*/) noexcept {
return onnxruntime::CreateStatus(ORT_INVALID_ARGUMENT, "[MigraphX/AMDGPU EP] EP factory does not support this method.");
}

static void ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* /*ep*/) {
static void ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* /*ep*/) noexcept {
// no-op as we never create an EP here.
}

Expand Down
19 changes: 19 additions & 0 deletions onnxruntime/test/providers/migraphx/migraphx_basic_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,24 @@ TEST(MIGraphXExecutionProviderTest, canEvalArgument) {
ASSERT_EQ(canEvalNodeArgument(gv, node2, {1}, input_nodes), true);
}

static bool SessionHasEp(Ort::Session& session, const char* ep_name) {
// Access the underlying InferenceSession.
const OrtSession* ort_session = session;
const InferenceSession* s = reinterpret_cast<const InferenceSession*>(ort_session);
bool has_ep = false;

for (const auto& provider : s->GetRegisteredProviderTypes()) {
if (provider == ep_name) {
has_ep = true;
break;
}
}
return has_ep;
}

#if defined(WIN32)
// Tests autoEP feature to automatically select an EP that supports the GPU.
// Currently only works on Windows.
TEST(MIGraphXExecutionProviderTest, AutoEp_PreferGpu) {
PathString model_name = ORT_TSTR("migraphx_basic_test.onnx");

Expand All @@ -212,6 +230,7 @@ TEST(MIGraphXExecutionProviderTest, AutoEp_PreferGpu) {

env.UnregisterExecutionProviderLibrary(kMIGraphXExecutionProvider);
}
#endif

} // namespace test
} // namespace onnxruntime
Loading