7 changes: 7 additions & 0 deletions include/onnxruntime/core/framework/ortdevice.h
@@ -150,6 +150,13 @@ struct OrtDevice {
return alignment < other.alignment;
}

+ bool EqualIgnoringAlignment(const OrtDevice& other) const {
+   return device_type == other.device_type &&
+          memory_type == other.memory_type &&
+          vendor_id == other.vendor_id &&
+          device_id == other.device_id;
+ }

private:
// Device type.
int32_t device_type : 8;
11 changes: 10 additions & 1 deletion include/onnxruntime/core/session/environment.h
@@ -106,6 +106,15 @@ class Environment {
return shared_allocators_;
}

+ /**
+  * Returns an AllocatorPtr for a shared IAllocator-based allocator if it matches the memory info.
+  * The OrtMemoryInfo name and whether it's an arena or device allocator are ignored in the lookup, as is the
+  * alignment.
+  * The user calling this function is not expected to know the alignment, and we expect the allocator instance to be
+  * created with a valid alignment for the device.
+  */
+ AllocatorPtr GetRegisteredSharedAllocator(const OrtMemoryInfo& mem_info) const;

/**
* Removes registered allocator that was previously registered for sharing between multiple sessions.
*/
@@ -171,7 +180,7 @@
std::unique_ptr<onnxruntime::concurrency::ThreadPool> inter_op_thread_pool_;
bool create_global_thread_pools_{false};

- std::mutex mutex_;
+ mutable std::mutex mutex_;

// shared allocators from various sources.
// CreateAndRegisterAllocator[V2]: IAllocator allocators created by ORT
4 changes: 4 additions & 0 deletions onnxruntime/core/providers/cuda/cuda_provider_factory.cc
@@ -734,6 +734,10 @@
}
*/

+ // guard against bad device discovery. max devices we expect to add is num_cuda_devices. if we're attempting
+ // to add more than that we have duplicates in the `devices` array.
+ max_ep_devices = std::min(max_ep_devices, static_cast<size_t>(num_cuda_devices));

GitHub Actions / Optional Lint C++ (line 739): [cpplint] Add #include <algorithm> for min [build/include_what_you_use] [4]

int16_t device_id = 0;
for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) {
const OrtHardwareDevice& device = *devices[i];
34 changes: 27 additions & 7 deletions onnxruntime/core/session/environment.cc
@@ -72,21 +72,23 @@ ProviderInfo_CUDA& GetProviderInfo_CUDA();
#endif // defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE)

namespace {
- // Ignore whether there is an arena wrapping the allocator by excluding OrtMemoryInfo.alloc_type from the comparison
+ // Ignore whether there is an arena wrapping the allocator by excluding OrtMemoryInfo.alloc_type from the comparison.
static bool AreOrtMemoryInfosEquivalent(
const OrtMemoryInfo& left, const OrtMemoryInfo& right,
- bool match_name = true) {
+ bool match_name = true,
+ bool ignore_alignment = false) {
return left.mem_type == right.mem_type &&
- left.device == right.device &&
+ (ignore_alignment ? left.device.EqualIgnoringAlignment(right.device) : left.device == right.device) &&
(!match_name || strcmp(left.name, right.name) == 0);
}

std::vector<AllocatorPtr>::const_iterator FindExistingAllocator(const std::vector<AllocatorPtr>& allocators,
const OrtMemoryInfo& mem_info,
- bool match_name = true) {
+ bool match_name = true,
+ bool ignore_alignment = false) {
auto ite = std::find_if(std::begin(allocators),
std::end(allocators),
- [&mem_info, match_name](const AllocatorPtr& alloc_ptr) {
+ [&mem_info, match_name, ignore_alignment](const AllocatorPtr& alloc_ptr) {
// We want to do the equality checking of 2 OrtMemoryInfos sans the OrtAllocatorType field.
// This is because we want to avoid registering two allocators for the same device that just
// differ on OrtAllocatorType.
@@ -96,7 +98,8 @@ std::vector<AllocatorPtr>::const_iterator FindExistingAllocato
// OrtDeviceAllocator (which is the only accepted value while registering a custom allocator).
// If we allowed this, it could potentially cause a lot of confusion as to which shared allocator
// to use for that device and we want to avoid having any ugly logic around this.
- return AreOrtMemoryInfosEquivalent(alloc_ptr->Info(), mem_info, match_name);
+ return AreOrtMemoryInfosEquivalent(alloc_ptr->Info(), mem_info,
+                                    match_name, ignore_alignment);
});

return ite;
@@ -428,8 +431,25 @@ Status Environment::CreateAndRegisterAllocatorV2(const std::string& provider_typ
}

Environment::~Environment() {
- // need to make sure all the OrtAllocator instances are released prior to any plugin EPs being freed
+ // need to make sure all the OrtAllocator instances are released prior to any plugin EPs being freed.
+ // this is because any entry in shared_allocators_ wrapping an OrtAllocator from a plugin EP owns the OrtAllocator
+ // instance and will call Release on it. If the plugin EP has been freed the Release will fail.
shared_allocators_.clear();

+ #if !defined(ORT_MINIMAL_BUILD)
+ // unregister any remaining EP libraries so they're cleaned up in a deterministic way.
+ while (!ep_libraries_.empty()) {
+   auto it = ep_libraries_.begin();
+   ORT_IGNORE_RETURN_VALUE(UnregisterExecutionProviderLibrary(it->first));
+ }
+ #endif
}

+ AllocatorPtr Environment::GetRegisteredSharedAllocator(const OrtMemoryInfo& mem_info) const {
+   std::lock_guard<std::mutex> lock{mutex_};
+
+   auto it = FindExistingAllocator(shared_allocators_, mem_info, /*match_name*/ false, /*ignore_alignment*/ true);
+   return it != shared_allocators_.end() ? *it : nullptr;
+ }

Status Environment::GetSharedAllocator(const OrtMemoryInfo& mem_info, OrtAllocator*& allocator) {
71 changes: 41 additions & 30 deletions onnxruntime/python/onnxruntime_inference_collection.py
@@ -21,7 +21,7 @@
import onnxruntime


- def get_ort_device_type(device_type: str, device_index) -> C.OrtDevice:
+ def get_ort_device_type(device_type: str) -> int:
if device_type == "cuda":
return C.OrtDevice.cuda()
elif device_type == "cann":
@@ -32,8 +32,10 @@ def get_ort_device_type(device_type: str, device_index) -> C.OrtDevice:
return C.OrtDevice.dml()
elif device_type == "webgpu":
return C.OrtDevice.webgpu()
elif device_type == "ort":
return C.get_ort_device(device_index).device_type()
elif device_type == "gpu":
return C.OrtDevice.gpu()
elif device_type == "npu":
return C.OrtDevice.npu()
else:
raise Exception("Unsupported device type: " + device_type)

@@ -765,7 +767,7 @@ def bind_input(self, name, device_type, device_id, element_type, shape, buffer_p
self._iobinding.bind_input(
name,
C.OrtDevice(
- get_ort_device_type(device_type, device_id),
+ get_ort_device_type(device_type),
C.OrtDevice.default_memory(),
device_id,
),
@@ -812,7 +814,7 @@ def bind_output(
self._iobinding.bind_output(
name,
C.OrtDevice(
- get_ort_device_type(device_type, device_id),
+ get_ort_device_type(device_type),
C.OrtDevice.default_memory(),
device_id,
),
@@ -823,7 +825,7 @@
self._iobinding.bind_output(
name,
C.OrtDevice(
- get_ort_device_type(device_type, device_id),
+ get_ort_device_type(device_type),
C.OrtDevice.default_memory(),
device_id,
),
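Since `get_ort_device_type` is now keyed purely on the device-type string, binding by device name is unchanged from a caller's perspective. A minimal I/O binding sketch (the model path and the "input"/"output" tensor names are hypothetical):

```python
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])  # hypothetical model
io = sess.io_binding()

x = ort.OrtValue.ortvalue_from_numpy(np.zeros((1, 3), dtype=np.float32))
io.bind_ortvalue_input("input", x)  # bind an existing OrtValue by name
io.bind_output("output", "cpu")     # the device_type string resolves via get_ort_device_type

sess.run_with_iobinding(io)
result = io.get_outputs()[0].numpy()
```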
@@ -889,26 +891,23 @@ def _get_c_value(self) -> C.OrtValue:
return self._ortvalue

@classmethod
- def ortvalue_from_numpy(cls, numpy_obj: np.ndarray, /, device_type="cpu", device_id=0) -> OrtValue:
+ def ortvalue_from_numpy(cls, numpy_obj: np.ndarray, /, device_type="cpu", device_id=0, vendor_id=-1) -> OrtValue:
"""
Factory method to construct an OrtValue (which holds a Tensor) from a given Numpy object
A copy of the data in the Numpy object is held by the OrtValue only if the device is NOT cpu

:param numpy_obj: The Numpy object to construct the OrtValue from
:param device_type: e.g. cpu, cuda, cann, cpu by default
:param device_id: device id, e.g. 0
+ :param vendor_id: The device's PCI vendor id. If provided, the device_type should be "gpu" or "npu".
"""
# Hold a reference to the numpy object (if device_type is 'cpu') as the OrtValue
# is backed directly by the data buffer of the numpy object and so the numpy object
# must be around until this OrtValue instance is around
return cls(
C.OrtValue.ortvalue_from_numpy(
numpy_obj,
- C.OrtDevice(
-     get_ort_device_type(device_type, device_id),
-     C.OrtDevice.default_memory(),
-     device_id,
- ),
+ OrtDevice.make(device_type, device_id, vendor_id)._get_c_device(),
),
numpy_obj if device_type.lower() == "cpu" else None,
)
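A usage sketch for the new `vendor_id` parameter (the vendor id value 0x10DE and the availability of a GPU-enabled build are assumptions, not part of this diff):

```python
import numpy as np
import onnxruntime as ort

x = np.arange(6, dtype=np.float32).reshape(2, 3)

# Default path: CPU, where the OrtValue is backed directly by the numpy buffer.
cpu_value = ort.OrtValue.ortvalue_from_numpy(x)

# Generic-device path added by this change: "gpu" or "npu" plus a PCI vendor id.
gpu_value = ort.OrtValue.ortvalue_from_numpy(x, "gpu", 0, vendor_id=0x10DE)
```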
@@ -929,7 +928,7 @@ def ortvalue_from_numpy_with_onnx_type(cls, data: np.ndarray, /, onnx_element_ty

@classmethod
def ortvalue_from_shape_and_type(
- cls, shape: Sequence[int], element_type, device_type: str = "cpu", device_id: int = 0
+ cls, shape: Sequence[int], element_type, device_type: str = "cpu", device_id: int = 0, vendor_id: int = -1
) -> OrtValue:
"""
Factory method to construct an OrtValue (which holds a Tensor) from given shape and element_type
@@ -938,31 +937,27 @@
:param element_type: The data type of the elements. It can be either numpy type (like numpy.float32) or an integer for onnx type (like onnx.TensorProto.BFLOAT16).
:param device_type: e.g. cpu, cuda, cann, cpu by default
:param device_id: device id, e.g. 0
+ :param vendor_id: If provided, the device_type should be "gpu" or "npu".
"""

+ device = OrtDevice.make(device_type, device_id, vendor_id)._get_c_device()

# Integer for onnx element type (see https://onnx.ai/onnx/api/mapping.html).
# This is helpful for some data type (like TensorProto.BFLOAT16) that is not available in numpy.
if isinstance(element_type, int):
return cls(
C.OrtValue.ortvalue_from_shape_and_onnx_type(
shape,
element_type,
- C.OrtDevice(
-     get_ort_device_type(device_type, device_id),
-     C.OrtDevice.default_memory(),
-     device_id,
- ),
+ device,
)
)

return cls(
C.OrtValue.ortvalue_from_shape_and_type(
shape,
element_type,
- C.OrtDevice(
-     get_ort_device_type(device_type, device_id),
-     C.OrtDevice.default_memory(),
-     device_id,
- ),
+ device,
)
)
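A sketch of the shape-and-type factory with the new parameter; the `onnx` import and the vendor id are illustrative assumptions:

```python
import numpy as np
import onnxruntime as ort
from onnx import TensorProto

# numpy dtype on the default CPU device (existing behavior).
v1 = ort.OrtValue.ortvalue_from_shape_and_type([2, 3], np.float32)

# ONNX integer element type (useful for types numpy lacks, e.g. BFLOAT16) on a generic GPU device.
v2 = ort.OrtValue.ortvalue_from_shape_and_type([2, 3], TensorProto.BFLOAT16, "gpu", 0, vendor_id=0x10DE)
```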

@@ -1085,21 +1080,37 @@ def _get_c_device(self):
return self._ort_device

@staticmethod
- def make(ort_device_name, device_id):
-     return OrtDevice(
-         C.OrtDevice(
-             get_ort_device_type(ort_device_name, device_id),
-             C.OrtDevice.default_memory(),
-             device_id,
-         )
-     )
+ def make(ort_device_name, device_id, vendor_id=-1):
+     if vendor_id < 0:
+         # backwards compatibility with predefined OrtDevice names
+         return OrtDevice(
+             C.OrtDevice(
+                 get_ort_device_type(ort_device_name),
+                 C.OrtDevice.default_memory(),
+                 device_id,
+             )
+         )
+     else:
+         # generic. use GPU or NPU for ort_device_name and provide a vendor id.
+         # vendor id of 0 is valid in some cases (e.g. webgpu is generic and does not have a vendor id)
+         return OrtDevice(
+             C.OrtDevice(
+                 get_ort_device_type(ort_device_name),
+                 C.OrtDevice.default_memory(),
+                 vendor_id,
+                 device_id,
+             )
+         )

def device_id(self):
return self._ort_device.device_id()

def device_type(self):
return self._ort_device.device_type()

+ def device_vendor_id(self):
+     return self._ort_device.vendor_id()
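A short sketch of the two `OrtDevice.make` paths; the import path mirrors where this class is defined, and the vendor id is an illustrative value:

```python
from onnxruntime.capi.onnxruntime_inference_collection import OrtDevice

# Predefined-name path: vendor_id is omitted, so the original three-argument C.OrtDevice ctor is used.
cuda_dev = OrtDevice.make("cuda", 0)

# Generic path added by this change: "gpu"/"npu" plus an explicit PCI vendor id (0x10DE is illustrative).
gpu_dev = OrtDevice.make("gpu", 0, vendor_id=0x10DE)
print(gpu_dev.device_type(), gpu_dev.device_id(), gpu_dev.device_vendor_id())
```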


class SparseTensor:
"""
Expand Down