Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Reenable device check and change factory creation
  • Loading branch information
owenzhangzhengzhong authored and wonchung-microsoft committed Jul 21, 2025
commit c93b3493a5a9a4cb7eca2efc30db1bea6e4ff246
14 changes: 8 additions & 6 deletions onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
return std::make_shared<MIGraphXProviderFactory>(info);
}

/*
//TODO: Interface change might require changes in other parts of win-onnxruntime?
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory(const void* param) {
if (param == nullptr) {
Expand All @@ -97,8 +98,9 @@
UpdateProviderOptions(&info, *provider_options);
return std::make_shared<MIGraphXProviderFactory>(info);
}
*/

/* std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory(const void* provider_options) override {
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory(const void* provider_options) override {
auto& options = *reinterpret_cast<const OrtMIGraphXProviderOptions*>(provider_options);
MIGraphXExecutionProviderInfo info;
info.device_id = static_cast<OrtDevice::DeviceId>(options.device_id);
Expand All @@ -125,7 +127,7 @@
info.arena_extend_strategy = static_cast<onnxruntime::ArenaExtendStrategy>(options.migraphx_arena_extend_strategy);
info.mem_limit = options.migraphx_mem_limit;
return std::make_shared<MIGraphXProviderFactory>(info);
}*/
}

void UpdateProviderOptions(void* provider_options, const ProviderOptions& options) override {
auto internal_options = onnxruntime::MIGraphXExecutionProviderInfo::FromProviderOptions(options);
Expand Down Expand Up @@ -173,9 +175,9 @@
const OrtSessionOptions& session_options,
const OrtLogger& logger,
std::unique_ptr<IExecutionProvider>& ep) override {
if (num_devices != 1) {
return Status(common::ONNXRUNTIME, ORT_EP_FAIL, "[MigraphX/AMDGPU EP] only supports one device.");
}
//if (num_devices != 1) {
// return Status(common::ONNXRUNTIME, ORT_EP_FAIL, "[MigraphX/AMDGPU EP] only supports one device.");
//}

const ConfigOptions* config_options = &session_options.GetConfigOptions();

Expand Down Expand Up @@ -243,7 +245,7 @@

for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) {
const OrtHardwareDevice& device = *devices[i];
if (factory->ort_api.HardwareDevice_Type(&device) == factory->ort_hw_device_type || true){
if (factory->ort_api.HardwareDevice_Type(&device) == factory->ort_hw_device_type){
//factory->ort_api.HardwareDevice_VendorId(&device) == factory->vendor_id) {
OrtKeyValuePairs* ep_options = nullptr;
factory->ort_api.CreateKeyValuePairs(&ep_options);
Expand Down Expand Up @@ -272,10 +274,10 @@

const OrtApi& ort_api;
const std::string ep_name;
const std::string vendor{"AMD"};

Check warning on line 277 in onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc:277: Add #include <string> for string [build/include_what_you_use] [4]

// AMD vendor ID. Refer to the ACPI ID registry (search AMD): https://uefi.org/ACPI_ID_List
const uint32_t vendor_id{0x1022}; //TODO: set correct value for AMD GPU

Check warning on line 280 in onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2] Raw Output: onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc:280: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
const OrtHardwareDeviceType ort_hw_device_type; // Supported OrtHardwareDevice
};

Expand All @@ -288,7 +290,7 @@
const OrtApi* ort_api = ort_api_base->GetApi(ORT_API_VERSION);

// Factory could use registration_name or define its own EP name.
auto factory_gpu = std::make_unique<MigraphXEpFactory>(*ort_api,

Check warning on line 293 in onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <memory> for make_unique<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc:293: Add #include <memory> for make_unique<> [build/include_what_you_use] [4]
onnxruntime::kMIGraphXExecutionProvider,
OrtHardwareDeviceType_GPU);

Expand Down
Loading