Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
fix compilation after cherry-pick
  • Loading branch information
apwojcik committed Jul 23, 2025
commit fc7a6c61d0f3a7976cf86b93c98290fbfac5414c
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/migraphx/migraphx_allocator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ void MIGraphXAllocator::CheckDevice() const {
int current_device;
auto hip_err = hipGetDevice(&current_device);
if (hip_err == hipSuccess) {
ORT_ENFORCE(current_device == Info().id);
ORT_ENFORCE(current_device == Info().device.Id());
}
#endif
}
Expand Down
14 changes: 7 additions & 7 deletions onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,10 @@ struct MIGraphX_Provider : Provider {
const OrtSessionOptions& session_options,
const OrtLogger& logger,
std::unique_ptr<IExecutionProvider>& ep) override {
ORT_UNUSED_PARAMETER(num_devices);
const ConfigOptions* config_options = &session_options.GetConfigOptions();

std::array<const void*, 2> configs_array = {&provider_options, config_options};
const void* arg = reinterpret_cast<const void*>(&configs_array);
auto ep_factory = CreateExecutionProviderFactory(&provider_options);
ep = ep_factory->CreateProvider(session_options, logger);

Expand All @@ -181,7 +181,7 @@ struct MigraphXEpFactory : OrtEpFactory {
const char* ep_name,
OrtHardwareDeviceType hw_type,
const OrtLogger& default_logger_in)
: ort_api{ort_api_in}, ep_name{ep_name}, ort_hw_device_type{hw_type}, default_logger{default_logger_in} {
: ort_api{ort_api_in}, default_logger{default_logger_in}, ep_name{ep_name}, ort_hw_device_type{hw_type} {
GetName = GetNameImpl;
GetVendor = GetVendorImpl;
GetSupportedDevices = GetSupportedDevicesImpl;
Expand All @@ -191,12 +191,12 @@ struct MigraphXEpFactory : OrtEpFactory {

// Returns the name for the EP. Each unique factory configuration must have a unique name.
// Ex: a factory that supports NPU should have a different than a factory that supports GPU.
static const char* GetNameImpl(const OrtEpFactory* this_ptr) {
static const char* GetNameImpl(const OrtEpFactory* this_ptr) noexcept {
const auto* factory = static_cast<const MigraphXEpFactory*>(this_ptr);
return factory->ep_name.c_str();
}

static const char* GetVendorImpl(const OrtEpFactory* this_ptr) {
static const char* GetVendorImpl(const OrtEpFactory* this_ptr) noexcept{
const auto* factory = static_cast<const MigraphXEpFactory*>(this_ptr);
return factory->vendor.c_str();
}
Expand All @@ -212,7 +212,7 @@ struct MigraphXEpFactory : OrtEpFactory {
size_t num_devices,
OrtEpDevice** ep_devices,
size_t max_ep_devices,
size_t* p_num_ep_devices) {
size_t* p_num_ep_devices) noexcept{
size_t& num_ep_devices = *p_num_ep_devices;
auto* factory = static_cast<MigraphXEpFactory*>(this_ptr);

Expand All @@ -237,11 +237,11 @@ struct MigraphXEpFactory : OrtEpFactory {
_In_ size_t /*num_devices*/,
_In_ const OrtSessionOptions* /*session_options*/,
_In_ const OrtLogger* /*logger*/,
_Out_ OrtEp** /*ep*/) {
_Out_ OrtEp** /*ep*/) noexcept{
return onnxruntime::CreateStatus(ORT_INVALID_ARGUMENT, "[MigraphX/AMDGPU EP] EP factory does not support this method.");
}

static void ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* /*ep*/) {
static void ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* /*ep*/) noexcept{
// no-op as we never create an EP here.
}

Expand Down
15 changes: 15 additions & 0 deletions onnxruntime/test/providers/migraphx/migraphx_basic_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,21 @@ TEST(MIGraphXExecutionProviderTest, canEvalArgument) {
ASSERT_EQ(canEvalNodeArgument(gv, node2, {1}, input_nodes), true);
}

static bool SessionHasEp(Ort::Session& session, const char* ep_name) {
// Access the underlying InferenceSession.
const OrtSession* ort_session = session;
const InferenceSession* s = reinterpret_cast<const InferenceSession*>(ort_session);
bool has_ep = false;

for (const auto& provider : s->GetRegisteredProviderTypes()) {
if (provider == ep_name) {
has_ep = true;
break;
}
}
return has_ep;
}

TEST(MIGraphXExecutionProviderTest, AutoEp_PreferGpu) {
PathString model_name = ORT_TSTR("migraphx_basic_test.onnx");

Expand Down