Merged
Make shared allocator available for all devices in OrtValueFromShapeAndType
skottmckay committed Jul 21, 2025
commit 42d8b0816b3cc6ee911b18fdad90cbdb69ff51dc
69 changes: 40 additions & 29 deletions onnxruntime/python/onnxruntime_pybind_ortvalue.cc
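
For context, this native function backs the Python-level OrtValue.ortvalue_from_shape_and_type API. A minimal usage sketch, assuming a standard onnxruntime wheel (parameter names as in the public Python API; shapes and device strings are illustrative):

import numpy as np
import onnxruntime as ort

# Allocate an uninitialized tensor on the requested device; the C++ code in
# this diff selects the allocator. "cpu" resolves via GetAllocator().
cpu_val = ort.OrtValue.ortvalue_from_shape_and_type(
    shape=[2, 3], element_type=np.float32, device_type="cpu", device_id=0)

# "cuda" resolves via GetCudaAllocator() in a CUDA build; with this change,
# other builds can still succeed if a plugin EP has registered a shared
# allocator for the device.
gpu_val = ort.OrtValue.ortvalue_from_shape_and_type(
    shape=[2, 3], element_type=np.float32, device_type="cuda", device_id=0)
print(gpu_val.device_name())
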
@@ -26,45 +26,56 @@ std::unique_ptr<OrtValue> OrtValueFromShapeAndType(const std::vector<int64_t>& s
 
   if (strcmp(GetDeviceName(device), CPU) == 0) {
     allocator = GetAllocator();
-  } else if (strcmp(GetDeviceName(device), CUDA) == 0) {
+  } else {
+#if !defined(ORT_MINIMAL_BUILD)
+    // if a plugin EP has been registered we can get a shared allocator from the environment.
+    // we use this as the fallback option.
+    allocator = GetSharedAllocator(device);
+#endif
+
+    if (strcmp(GetDeviceName(device), CUDA) == 0) {
 #ifdef USE_CUDA
-    if (!IsCudaDeviceIdValid(logging::LoggingManager::DefaultLogger(), device.Id())) {
-      throw std::runtime_error("The provided device id doesn't match any available GPUs on the machine.");
-    }
+      if (!IsCudaDeviceIdValid(logging::LoggingManager::DefaultLogger(), device.Id())) {
+        throw std::runtime_error("The provided device id doesn't match any available GPUs on the machine.");
+      }
 
-    allocator = GetCudaAllocator(device.Id());
-#elif !defined(ORT_MINIMAL_BUILD)
-    // if a plugin EP has been registered we can get a shared allocator from the environment
-    allocator = GetSharedAllocator(device);
+      allocator = GetCudaAllocator(device.Id());
 #endif
-    if (!allocator) {
-      throw std::runtime_error(
-          "Can't allocate memory on the CUDA device using this package of OnnxRuntime. "
-          "Please use the CUDA package of OnnxRuntime to use this feature.");
-    }
-  } else if (strcmp(GetDeviceName(device), HIP) == 0) {
+      if (!allocator) {
+        throw std::runtime_error(
+            "Can't allocate memory on the CUDA device using this package of OnnxRuntime. "
+            "Please use the CUDA package of OnnxRuntime to use this feature.");
+      }
+    } else if (strcmp(GetDeviceName(device), HIP) == 0) {
 #if USE_ROCM
-    if (!IsRocmDeviceIdValid(logging::LoggingManager::DefaultLogger(), device.Id())) {
-      throw std::runtime_error("The provided device id doesn't match any available GPUs on the machine.");
-    }
-    allocator = GetRocmAllocator(device.Id());
+      if (!IsRocmDeviceIdValid(logging::LoggingManager::DefaultLogger(), device.Id())) {
+        throw std::runtime_error("The provided device id doesn't match any available GPUs on the machine.");
+      }
+      allocator = GetRocmAllocator(device.Id());
 #elif USE_MIGRAPHX
-    allocator = GetMIGraphXAllocator(device.Id());
+      allocator = GetMIGraphXAllocator(device.Id());
 #else
-    throw std::runtime_error(
-        "Can't allocate memory on the AMD device using this package of OnnxRuntime. "
-        "Please use the ROCm package of OnnxRuntime to use this feature.");
+      if (!allocator) {
+        throw std::runtime_error(
+            "Can't allocate memory on the AMD device using this package of OnnxRuntime. "
+            "Please use the ROCm package of OnnxRuntime to use this feature.");
+      }
 #endif
-  } else if (strcmp(GetDeviceName(device), DML) == 0) {
+    } else if (strcmp(GetDeviceName(device), DML) == 0) {
 #if USE_DML
-    allocator = GetDmlAllocator(device.Id());
+      allocator = GetDmlAllocator(device.Id());
 #else
-    throw std::runtime_error(
-        "Can't allocate memory on the DirectML device using this package of OnnxRuntime. "
-        "Please use the DirectML package of OnnxRuntime to use this feature.");
+      if (!allocator) {
+        throw std::runtime_error(
+            "Can't allocate memory on the DirectML device using this package of OnnxRuntime. "
+            "Please use the DirectML package of OnnxRuntime to use this feature.");
+      }
 #endif
-  } else {
-    throw std::runtime_error("Unsupported device: Cannot place the OrtValue on this device");
+    } else {
+      if (!allocator) {
+        throw std::runtime_error("Unsupported device: Cannot place the OrtValue on this device");
+      }
+    }
   }
 
   auto ml_value = std::make_unique<OrtValue>();
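
Net effect of the hunk: GetSharedAllocator(device) now runs first for every non-CPU device as a fallback, the build-specific allocator (CUDA/ROCm/MIGraphX/DML) overwrites it when compiled in, and the errors fire only when neither path produced an allocator. From Python, the failure mode on a package with neither a matching build nor a plugin EP should look as below (a sketch; pybind11 surfaces std::runtime_error as RuntimeError):

import numpy as np
import onnxruntime as ort

try:
    v = ort.OrtValue.ortvalue_from_shape_and_type(
        shape=[4], element_type=np.float32, device_type="cuda", device_id=0)
except RuntimeError as e:
    # In a CPU-only package with no plugin EP registered, no allocator is
    # found and the "Please use the CUDA package" error above is raised.
    print(e)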