Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/onnxruntime/core/graph/indexed_sub_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ struct IndexedSubGraph {
std::string domain; ///< Domain of customized SubGraph/FunctionProto
int since_version; ///< Since version of customized SubGraph/FunctionProto.

ONNX_NAMESPACE::OperatorStatus status; ///< Status of customized SubGraph/FunctionProto.
ONNX_NAMESPACE::OperatorStatus status{ONNX_NAMESPACE::OperatorStatus::STABLE}; ///< Status of customized SubGraph/FunctionProto.

std::vector<std::string> inputs; ///< Inputs of customized SubGraph/FunctionProto.
std::vector<std::string> outputs; ///< Outputs of customized SubGraph/FunctionProto.
Expand Down
2 changes: 1 addition & 1 deletion include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -5829,7 +5829,7 @@ struct OrtApi {
*
* \since Version 1.23.
*/
ORT_API2_STATUS(Graph_GetNodes, const OrtGraph* graph,
ORT_API2_STATUS(Graph_GetNodes, _In_ const OrtGraph* graph,
_Out_writes_(num_nodes) const OrtNode** nodes, _In_ size_t num_nodes);

/** \brief Get the parent node for the given graph, if any exists.
Expand Down
304 changes: 206 additions & 98 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h

Large diffs are not rendered by default.

247 changes: 222 additions & 25 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,16 @@ inline OrtMemType MemoryInfoImpl<T>::GetMemoryType() const {
return type;
}

// Returns the device memory type recorded in this OrtMemoryInfo
// (as reported by the MemoryInfoGetDeviceMemType C API).
template <typename T>
inline OrtDeviceMemoryType MemoryInfoImpl<T>::GetDeviceMemoryType() const {
  const OrtDeviceMemoryType device_mem_type = GetApi().MemoryInfoGetDeviceMemType(this->p_);
  return device_mem_type;
}

// Returns the hardware vendor id stored in this OrtMemoryInfo
// (as reported by the MemoryInfoGetVendorId C API).
template <typename T>
inline uint32_t MemoryInfoImpl<T>::GetVendorId() const {
  const uint32_t vendor_id = GetApi().MemoryInfoGetVendorId(this->p_);
  return vendor_id;
}

template <typename T>
template <typename U>
inline bool MemoryInfoImpl<T>::operator==(const MemoryInfoImpl<U>& o) const {
Expand All @@ -316,6 +326,12 @@ inline MemoryInfo::MemoryInfo(const char* name, OrtAllocatorType type, int id, O
ThrowOnError(GetApi().CreateMemoryInfo(name, type, id, mem_type, &this->p_));
}

// Creates an OrtMemoryInfo via the CreateMemoryInfo_V2 C API, which identifies the
// device by (device_type, vendor_id, device_id) and additionally records the device
// memory type, alignment and allocator type. Throws Ort::Exception on failure.
inline MemoryInfo::MemoryInfo(const char* name, OrtMemoryInfoDeviceType device_type, uint32_t vendor_id, uint32_t device_id,
OrtDeviceMemoryType mem_type, size_t alignment, OrtAllocatorType allocator_type) {
ThrowOnError(GetApi().CreateMemoryInfo_V2(name, device_type, vendor_id, device_id, mem_type, alignment,
allocator_type, &this->p_));
}

namespace detail {
template <typename T>
inline std::vector<std::string> ConstIoBindingImpl<T>::GetOutputNames() const {
Expand Down Expand Up @@ -404,33 +420,19 @@ inline std::vector<std::string> GetOutputNamesHelper(const OrtIoBinding* binding

// Retrieves all output values bound to an OrtIoBinding and wraps each OrtValue* in an
// Ort::Value (which takes ownership of it).
// NOTE(review): this span contained interleaved pre/post diff lines from the paste; this body
// is the resolved post-change version. The C API returns a heap array of OrtValue pointers;
// that array (not the values) is released via the supplied allocator once the values are wrapped.
inline std::vector<Value> GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) {
  std::vector<Value> result;
  size_t output_count = 0;
  OrtValue** output_buffer = nullptr;
  ThrowOnError(GetApi().GetBoundOutputValues(binding, allocator, &output_buffer, &output_count));
  if (output_count == 0) {
    return result;
  }

  // Frees the array of pointers on scope exit (including on exception).
  std::unique_ptr<void, AllocatedFree> buffer_g(output_buffer, AllocatedFree(allocator));

  // reserve() guarantees no reallocation during the emplace loop below.
  result.reserve(output_count);
  for (size_t i = 0; i < output_count; ++i) {
    result.emplace_back(output_buffer[i]);
  }
  return result;
}
Expand All @@ -446,6 +448,18 @@ inline ArenaCfg::ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial
ThrowOnError(GetApi().CreateArenaCfg(max_mem, arena_extend_strategy, initial_chunk_size_bytes, max_dead_bytes_per_chunk, &p_));
}

// Builds an OrtArenaCfg from a map of configuration key -> value.
// CreateArenaCfgV2 takes parallel key/value arrays, so the map is flattened first.
inline ArenaCfg::ArenaCfg(const std::unordered_map<std::string, size_t>& arena_config) {
  const size_t num_entries = arena_config.size();
  std::vector<const char*> config_keys;
  std::vector<size_t> config_values;
  config_keys.reserve(num_entries);
  config_values.reserve(num_entries);
  for (const auto& entry : arena_config) {
    config_keys.push_back(entry.first.c_str());
    config_values.push_back(entry.second);
  }
  ThrowOnError(GetApi().CreateArenaCfgV2(config_keys.data(), config_values.data(), num_entries, &p_));
}

// Creates a default OrtThreadingOptions instance (configure it via the Set* members).
inline ThreadingOptions::ThreadingOptions() {
ThrowOnError(GetApi().CreateThreadingOptions(&p_));
}
Expand Down Expand Up @@ -485,6 +499,78 @@ inline ThreadingOptions& ThreadingOptions::SetGlobalCustomJoinThreadFn(OrtCustom
return *this;
}

// Creates a default OrtTensorRTProviderOptionsV2 instance; configure it via Update/UpdateWithValue.
inline TensorRTProviderOptions::TensorRTProviderOptions() {
ThrowOnError(GetApi().CreateTensorRTProviderOptions(&this->p_));
}

// Applies a set of string key/value options to the underlying TensorRT provider options.
// The C API expects parallel key/value arrays, so the map is flattened first.
inline void TensorRTProviderOptions::Update(const std::unordered_map<std::string, std::string>& options) {
  const size_t num_options = options.size();
  std::vector<const char*> option_keys;
  std::vector<const char*> option_values;
  option_keys.reserve(num_options);
  option_values.reserve(num_options);
  for (const auto& option : options) {
    option_keys.push_back(option.first.c_str());
    option_values.push_back(option.second.c_str());
  }
  ThrowOnError(GetApi().UpdateTensorRTProviderOptions(p_, option_keys.data(), option_values.data(), num_options));
}

// Sets a single provider option whose value is passed as an opaque pointer rather than a
// string (the valid keys are defined by the UpdateTensorRTProviderOptionsWithValue C API).
inline void TensorRTProviderOptions::UpdateWithValue(const char* key, void* value) {
ThrowOnError(GetApi().UpdateTensorRTProviderOptionsWithValue(p_, key, value));
}

// Looks up a single TensorRT provider option by name via the C API.
// The pointee's type/lifetime are defined by the C API for the given key.
inline void* TensorRTProviderOptions::GetOptionByName(const char* name) const {
  void* option_value = nullptr;
  ThrowOnError(GetApi().GetTensorRTProviderOptionsByName(p_, name, &option_value));
  return option_value;
}

// Serializes the current TensorRT provider options into a single string.
// The C string returned by the API is freed once it has been copied into the std::string.
inline std::string TensorRTProviderOptions::GetTensorRTProviderOptionsAsString() const {
  AllocatorWithDefaultOptions allocator;
  char* serialized = nullptr;
  ThrowOnError(GetApi().GetTensorRTProviderOptionsAsString(p_, allocator, &serialized));
  std::unique_ptr<void, detail::AllocatedFree> serialized_g(serialized, detail::AllocatedFree(allocator));
  return std::string{serialized};
}

// Creates a default OrtCUDAProviderOptionsV2 instance; configure it via Update/UpdateWithValue.
inline CUDAProviderOptions::CUDAProviderOptions() {
ThrowOnError(GetApi().CreateCUDAProviderOptions(&this->p_));
}

// Applies a set of string key/value options to the underlying CUDA provider options.
// The C API expects parallel key/value arrays, so the map is flattened first.
inline void CUDAProviderOptions::Update(const std::unordered_map<std::string, std::string>& options) {
  const size_t num_options = options.size();
  std::vector<const char*> option_keys;
  std::vector<const char*> option_values;
  option_keys.reserve(num_options);
  option_values.reserve(num_options);
  for (const auto& option : options) {
    option_keys.push_back(option.first.c_str());
    option_values.push_back(option.second.c_str());
  }
  ThrowOnError(GetApi().UpdateCUDAProviderOptions(p_, option_keys.data(), option_values.data(), num_options));
}

// Serializes the current CUDA provider options into a single string.
// The C string returned by the API is freed once it has been copied into the std::string.
inline std::string CUDAProviderOptions::GetCUDAProviderOptionsAsString() const {
  AllocatorWithDefaultOptions allocator;
  char* serialized = nullptr;
  ThrowOnError(GetApi().GetCUDAProviderOptionsAsString(p_, allocator, &serialized));
  std::unique_ptr<void, detail::AllocatedFree> serialized_g(serialized, detail::AllocatedFree(allocator));
  return std::string{serialized};
}

// Sets a single provider option whose value is passed as an opaque pointer rather than a
// string (the valid keys are defined by the UpdateCUDAProviderOptionsWithValue C API).
inline void CUDAProviderOptions::UpdateWithValue(const char* key, void* value) {
ThrowOnError(GetApi().UpdateCUDAProviderOptionsWithValue(p_, key, value));
}

// Looks up a single CUDA provider option by name via the C API.
// The pointee's type/lifetime are defined by the C API for the given key.
inline void* CUDAProviderOptions::GetOptionByName(const char* name) const {
  void* option_value = nullptr;
  ThrowOnError(GetApi().GetCUDAProviderOptionsByName(p_, name, &option_value));
  return option_value;
}

// Creates an empty OrtPrepackedWeightsContainer. Throws Ort::Exception on failure.
inline PrepackedWeightsContainer::PrepackedWeightsContainer() {
ThrowOnError(GetApi().CreatePrepackedWeightsContainer(&this->p_));
}

namespace detail {
template <typename T>
inline const char* KeyValuePairsImpl<T>::GetValue(const char* key) const {
Expand Down Expand Up @@ -547,6 +633,10 @@ inline void KeyValuePairs::Remove(const char* key) {
GetApi().RemoveKeyValuePair(this->p_, key);
}

// Returns the native handle wrapped by this OrtSyncStream as an opaque void*
// (the concrete handle type depends on the execution provider -- TODO confirm against C API docs).
inline void* SyncStream::GetHandle() const {
return GetApi().SyncStream_GetHandle(this->p_);
}

namespace detail {
template <typename T>
inline OrtHardwareDeviceType HardwareDeviceImpl<T>::Type() const {
Expand Down Expand Up @@ -597,6 +687,19 @@ template <typename T>
inline ConstHardwareDevice EpDeviceImpl<T>::Device() const {
return ConstHardwareDevice(GetApi().EpDevice_Device(this->p_));
}

// Returns the OrtMemoryInfo this execution-provider device exposes for the given memory type.
template <typename T>
inline ConstMemoryInfo EpDeviceImpl<T>::GetMemoryInfo(OrtDeviceMemoryType memory_type) const {
  return ConstMemoryInfo{GetApi().EpDevice_MemoryInfo(this->p_, memory_type)};
}

// Creates a synchronization stream on this execution-provider device.
// stream_options are forwarded to the C API as an OrtKeyValuePairs*.
template <typename T>
inline SyncStream EpDeviceImpl<T>::CreateSyncStream(ConstKeyValuePairs stream_options) const {
  OrtSyncStream* sync_stream = nullptr;
  ThrowOnError(GetApi().CreateSyncStreamForEpDevice(this->p_, stream_options, &sync_stream));
  return SyncStream{sync_stream};
}
} // namespace detail

inline EpDevice::EpDevice(OrtEpFactory& ep_factory, ConstHardwareDevice& hardware_device,
Expand Down Expand Up @@ -676,6 +779,16 @@ inline Env& Env::CreateAndRegisterAllocatorV2(const std::string& provider_type,
return *this;
}

// Registers a custom allocator with the environment for sharing across sessions.
// Returns *this to allow call chaining.
inline Env& Env::RegisterAllocator(OrtAllocator* allocator) {
ThrowOnError(GetApi().RegisterAllocator(p_, allocator));
return *this;
}

// Unregisters the allocator previously registered for the given OrtMemoryInfo.
// Returns *this to allow call chaining.
inline Env& Env::UnregisterAllocator(const OrtMemoryInfo* mem_info) {
ThrowOnError(GetApi().UnregisterAllocator(p_, mem_info));
return *this;
}

inline Env& Env::RegisterExecutionProviderLibrary(const char* registration_name,
const std::basic_string<ORTCHAR_T>& path) {
ThrowOnError(GetApi().RegisterExecutionProviderLibrary(p_, registration_name, path.c_str()));
Expand Down Expand Up @@ -703,6 +816,41 @@ inline std::vector<ConstEpDevice> Env::GetEpDevices() const {
return devices;
}

// Copies each tensor in src_tensors to the corresponding entry of dst_tensors, optionally
// ordered on the given sync stream (may be nullptr). Returns a failing Status (rather than
// throwing) when the two vectors differ in size; an empty request succeeds trivially.
inline Status Env::CopyTensors(const std::vector<Value>& src_tensors,
                               const std::vector<Value>& dst_tensors,
                               OrtSyncStream* stream) const {
  // Value wraps a single OrtValue*, which is what makes the reinterpret_casts of the
  // vectors' contiguous storage below valid (same pattern as GetMemoryInfoForInputs).
  static_assert(sizeof(Value) == sizeof(OrtValue*), "Value must be compatible with OrtValue*");

  if (src_tensors.size() != dst_tensors.size()) {
    return Status("Source and destination tensor vectors must have the same size", ORT_INVALID_ARGUMENT);
  }
  if (src_tensors.empty()) {
    return Status();  // nothing to copy
  }

  const OrtValue* const* src_tensors_ptr = reinterpret_cast<const OrtValue* const*>(src_tensors.data());
  OrtValue* const* dst_tensors_ptr = reinterpret_cast<OrtValue* const*>(dst_tensors.data());
  OrtStatus* status = GetApi().CopyTensors(p_, src_tensors_ptr, dst_tensors_ptr, stream, src_tensors.size());
  return Status(status);
}

// Creates (and registers with the environment) a shared allocator for the given EP device
// and memory type. The environment owns the allocator; the returned wrapper is non-owning.
inline UnownedAllocator Env::CreateSharedAllocator(const OrtEpDevice* ep_device, OrtDeviceMemoryType mem_type,
                                                   OrtAllocatorType allocator_type,
                                                   const OrtKeyValuePairs* allocator_options) {
  OrtAllocator* p = nullptr;  // initialize: never pass an indeterminate pointer around
  ThrowOnError(GetApi().CreateSharedAllocator(p_, ep_device, mem_type, allocator_type, allocator_options, &p));
  return UnownedAllocator{p};
}

// Looks up the environment's shared allocator matching the given OrtMemoryInfo.
// The environment owns the allocator; the returned wrapper is non-owning.
inline UnownedAllocator Env::GetSharedAllocator(const OrtMemoryInfo* mem_info) {
  OrtAllocator* p = nullptr;  // initialize: never pass an indeterminate pointer around
  ThrowOnError(GetApi().GetSharedAllocator(p_, mem_info, &p));
  return UnownedAllocator{p};
}

// Releases the environment's shared allocator for the given EP device and memory type.
// Any UnownedAllocator previously obtained for it must no longer be used afterwards.
inline void Env::ReleaseSharedAllocator(const OrtEpDevice* ep_device,
OrtDeviceMemoryType mem_type) {
ThrowOnError(GetApi().ReleaseSharedAllocator(p_, ep_device, mem_type));
}

// Creates an OrtCustomOpDomain with the given domain name. Throws Ort::Exception on failure.
inline CustomOpDomain::CustomOpDomain(const char* domain) {
ThrowOnError(GetApi().CreateCustomOpDomain(domain, &p_));
}
Expand Down Expand Up @@ -1056,6 +1204,12 @@ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddExternalInitializersFrom
return *this;
}

// Appends the CPU execution provider to these session options.
// use_arena: non-zero enables the arena-based allocator for the CPU EP.
// Returns *this to allow call chaining.
template <typename T>
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CPU(int use_arena) {
ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CPU(this->p_, use_arena));
return *this;
}

template <typename T>
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options) {
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA(this->p_, &provider_options));
Expand Down Expand Up @@ -1298,9 +1452,9 @@ inline std::vector<std::string> ConstSessionImpl<T>::GetInputNames() const {
input_names.reserve(num_inputs);

for (size_t i = 0; i < num_inputs; ++i) {
char* name = nullptr;
char* name;
ThrowOnError(GetApi().SessionGetInputName(this->p_, i, allocator, &name));
input_names.push_back(name);
input_names.emplace_back(name);
allocator.Free(name);
}

Expand All @@ -1316,9 +1470,9 @@ inline std::vector<std::string> ConstSessionImpl<T>::GetOutputNames() const {
output_names.reserve(num_inputs);

for (size_t i = 0; i < num_inputs; ++i) {
char* name = nullptr;
char* name;
ThrowOnError(GetApi().SessionGetOutputName(this->p_, i, allocator, &name));
output_names.push_back(name);
output_names.emplace_back(name);
allocator.Free(name);
}

Expand All @@ -1334,14 +1488,44 @@ inline std::vector<std::string> ConstSessionImpl<T>::GetOverridableInitializerNa
initializer_names.reserve(num_initializers);

for (size_t i = 0; i < num_initializers; ++i) {
char* name = nullptr;
char* name;
ThrowOnError(GetApi().SessionGetOverridableInitializerName(this->p_, i, allocator, &name));
initializer_names.push_back(name);
initializer_names.emplace_back(name);
}

return initializer_names;
}

// Returns, for each model input, the OrtMemoryInfo describing where the session expects
// that input to live. The result is ordered like the session's inputs.
template <typename T>
inline std::vector<ConstMemoryInfo> ConstSessionImpl<T>::GetMemoryInfoForInputs() const {
  // ConstMemoryInfo wraps a single OrtMemoryInfo*, so the vector's storage can be handed
  // to the C API as an OrtMemoryInfo* array.
  static_assert(sizeof(ConstMemoryInfo) == sizeof(OrtMemoryInfo*),
                "ConstMemoryInfo must be compatible with OrtMemoryInfo*");

  auto num_inputs = GetInputCount();
  std::vector<ConstMemoryInfo> mem_infos(num_inputs);
  if (num_inputs == 0) {
    return mem_infos;  // avoid &mem_infos[0] on an empty vector (UB) and a null out-array
  }

  // data() instead of &mem_infos[0]: well-defined even in edge cases.
  ThrowOnError(GetApi().SessionGetMemoryInfoForInputs(this->p_,
                                                      reinterpret_cast<const OrtMemoryInfo**>(mem_infos.data()),
                                                      num_inputs));
  return mem_infos;
}

// Returns, for each model output, the OrtMemoryInfo describing where the session produces
// that output. The result is ordered like the session's outputs.
template <typename T>
inline std::vector<ConstMemoryInfo> ConstSessionImpl<T>::GetMemoryInfoForOutputs() const {
  // ConstMemoryInfo wraps a single OrtMemoryInfo*, so the vector's storage can be handed
  // to the C API as an OrtMemoryInfo* array.
  static_assert(sizeof(ConstMemoryInfo) == sizeof(OrtMemoryInfo*),
                "ConstMemoryInfo must be compatible with OrtMemoryInfo*");

  auto num_outputs = GetOutputCount();
  std::vector<ConstMemoryInfo> mem_infos(num_outputs);
  if (num_outputs == 0) {
    return mem_infos;  // avoid &mem_infos[0] on an empty vector (UB) and a null out-array
  }

  // data() instead of &mem_infos[0]: well-defined even in edge cases.
  ThrowOnError(GetApi().SessionGetMemoryInfoForOutputs(this->p_,
                                                       reinterpret_cast<const OrtMemoryInfo**>(mem_infos.data()),
                                                       num_outputs));
  return mem_infos;
}

template <typename T>
inline AllocatedStringPtr ConstSessionImpl<T>::GetInputNameAllocated(size_t index, OrtAllocator* allocator) const {
char* out;
Expand All @@ -1363,6 +1547,19 @@ inline AllocatedStringPtr ConstSessionImpl<T>::GetOverridableInitializerNameAllo
return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
}

// Returns, for each model input, the execution-provider device the session assigned to it.
// The result is ordered like the session's inputs.
template <typename T>
inline std::vector<ConstEpDevice> ConstSessionImpl<T>::GetEpDeviceForInputs() const {
  // Same pointer-compatibility guarantee the MemoryInfo getters assert (consistency).
  static_assert(sizeof(ConstEpDevice) == sizeof(OrtEpDevice*),
                "ConstEpDevice must be compatible with OrtEpDevice*");

  auto num_inputs = GetInputCount();
  std::vector<ConstEpDevice> input_devices(num_inputs);
  if (num_inputs == 0) {
    return input_devices;  // avoid &input_devices[0] on an empty vector (UB)
  }

  // data() instead of &input_devices[0]: well-defined even in edge cases.
  ThrowOnError(GetApi().SessionGetEpDeviceForInputs(this->p_,
                                                    reinterpret_cast<const OrtEpDevice**>(input_devices.data()),
                                                    num_inputs));
  return input_devices;
}

template <typename T>
inline uint64_t ConstSessionImpl<T>::GetProfilingStartTimeNs() const {
uint64_t out;
Expand Down Expand Up @@ -1857,15 +2054,15 @@ inline size_t ConstValueImpl<T>::GetTensorSizeInBytes() const {
// Returns a const pointer to the tensor's raw element buffer, typed as R.
// NOTE(review): this span contained interleaved pre/post diff lines from the paste; this body
// is the resolved post-change version, which uses the const-correct GetTensorData C API
// instead of const_cast + GetTensorMutableData.
template <typename T>
template <typename R>
inline const R* ConstValueImpl<T>::GetTensorData() const {
  const R* out;
  ThrowOnError(GetApi().GetTensorData(this->p_, reinterpret_cast<const void**>(&out)));
  return out;
}

// Returns a const pointer to the tensor's raw (untyped) element buffer.
// NOTE(review): this span contained interleaved pre/post diff lines from the paste; this body
// is the resolved post-change version, which uses the const-correct GetTensorData C API
// instead of const_cast + GetTensorMutableData.
template <typename T>
inline const void* ConstValueImpl<T>::GetTensorRawData() const {
  const void* out;
  ThrowOnError(GetApi().GetTensorData(this->p_, &out));
  return out;
}

Expand Down
Loading
Loading