Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions cmake/onnxruntime_java.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,7 @@ if (WIN32)
if(NOT onnxruntime_ENABLE_STATIC_ANALYSIS)
add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $<TARGET_FILE:onnxruntime> ${JAVA_PACKAGE_LIB_DIR}/$<TARGET_FILE_NAME:onnxruntime>)
add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $<TARGET_FILE:onnxruntime4j_jni> ${JAVA_PACKAGE_JNI_DIR}/$<TARGET_FILE_NAME:onnxruntime4j_jni>)
if (onnxruntime_USE_CUDA OR onnxruntime_USE_DNNL OR onnxruntime_USE_OPENVINO OR onnxruntime_USE_TENSORRT OR (onnxruntime_USE_QNN AND NOT onnxruntime_BUILD_QNN_EP_STATIC_LIB))
add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $<TARGET_FILE:onnxruntime_providers_shared> ${JAVA_PACKAGE_LIB_DIR}/$<TARGET_FILE_NAME:onnxruntime_providers_shared>)
endif()
add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $<TARGET_FILE:onnxruntime_providers_shared> ${JAVA_PACKAGE_LIB_DIR}/$<TARGET_FILE_NAME:onnxruntime_providers_shared>)
if (onnxruntime_USE_CUDA)
add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $<TARGET_FILE:onnxruntime_providers_cuda> ${JAVA_PACKAGE_LIB_DIR}/$<TARGET_FILE_NAME:onnxruntime_providers_cuda>)
endif()
Expand Down Expand Up @@ -205,9 +203,7 @@ if (WIN32)
else()
add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $<TARGET_FILE:onnxruntime> ${JAVA_PACKAGE_LIB_DIR}/$<TARGET_LINKER_FILE_NAME:onnxruntime>)
add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $<TARGET_FILE:onnxruntime4j_jni> ${JAVA_PACKAGE_JNI_DIR}/$<TARGET_LINKER_FILE_NAME:onnxruntime4j_jni>)
if (onnxruntime_USE_CUDA OR onnxruntime_USE_DNNL OR onnxruntime_USE_OPENVINO OR onnxruntime_USE_TENSORRT OR (onnxruntime_USE_QNN AND NOT onnxruntime_BUILD_QNN_EP_STATIC_LIB))
add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $<TARGET_FILE:onnxruntime_providers_shared> ${JAVA_PACKAGE_LIB_DIR}/$<TARGET_LINKER_FILE_NAME:onnxruntime_providers_shared>)
endif()
add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $<TARGET_FILE:onnxruntime_providers_shared> ${JAVA_PACKAGE_LIB_DIR}/$<TARGET_LINKER_FILE_NAME:onnxruntime_providers_shared>)
if (onnxruntime_USE_CUDA)
add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $<TARGET_FILE:onnxruntime_providers_cuda> ${JAVA_PACKAGE_LIB_DIR}/$<TARGET_LINKER_FILE_NAME:onnxruntime_providers_cuda>)
endif()
Expand Down
4 changes: 4 additions & 0 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -1657,6 +1657,10 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
add_custom_command(TARGET onnxruntime_providers_qnn POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy ${QNN_LIB_FILES} ${JAVA_NATIVE_TEST_DIR})
endif()
if (WIN32)
set(EXAMPLE_PLUGIN_EP_DST_FILE_NAME $<IF:$<BOOL:${WIN32}>,$<TARGET_FILE_NAME:example_plugin_ep>,$<TARGET_LINKER_FILE_NAME:example_plugin_ep>>)
add_custom_command(TARGET custom_op_library POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $<TARGET_FILE:example_plugin_ep> ${JAVA_NATIVE_TEST_DIR}/${EXAMPLE_PLUGIN_EP_DST_FILE_NAME})
endif()

# delegate to gradle's test runner

Expand Down
18 changes: 16 additions & 2 deletions java/src/main/java/ai/onnxruntime/OnnxRuntime.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ final class OnnxRuntime {
private static final int ORT_API_VERSION_13 = 13;
// Post 1.13 builds of the ORT API
private static final int ORT_API_VERSION_14 = 14;
// Post 1.22 builds of the ORT API
private static final int ORT_API_VERSION_23 = 23;

// The initial release of the ORT training API.
private static final int ORT_TRAINING_API_VERSION_1 = 1;
Expand Down Expand Up @@ -103,6 +105,9 @@ final class OnnxRuntime {
/** The Training API handle. */
static long ortTrainingApiHandle;

/** The Compile API handle. */
static long ortCompileApiHandle;

/** Is training enabled in the native library */
static boolean trainingEnabled;

Expand Down Expand Up @@ -174,12 +179,13 @@ static synchronized void init() throws IOException {
}
load(ONNXRUNTIME_JNI_LIBRARY_NAME);

ortApiHandle = initialiseAPIBase(ORT_API_VERSION_14);
ortApiHandle = initialiseAPIBase(ORT_API_VERSION_23);
if (ortApiHandle == 0L) {
throw new IllegalStateException(
"There is a mismatch between the ORT class files and the ORT native library, and the native library could not be loaded");
}
ortTrainingApiHandle = initialiseTrainingAPIBase(ortApiHandle, ORT_API_VERSION_14);
ortTrainingApiHandle = initialiseTrainingAPIBase(ortApiHandle, ORT_API_VERSION_23);
ortCompileApiHandle = initialiseCompileAPIBase(ortApiHandle);
trainingEnabled = ortTrainingApiHandle != 0L;
providers = initialiseProviders(ortApiHandle);
version = initialiseVersion();
Expand Down Expand Up @@ -497,6 +503,14 @@ private static EnumSet<OrtProvider> initialiseProviders(long ortApiHandle) {
*/
private static native long initialiseTrainingAPIBase(long apiHandle, int apiVersionNumber);

/**
* Get a reference to the compile API struct.
*
* @param apiHandle The ORT API struct pointer.
* @return A pointer to the compile API struct.
*/
private static native long initialiseCompileAPIBase(long apiHandle);

/**
* Gets the array of available providers.
*
Expand Down
82 changes: 81 additions & 1 deletion java/src/main/java/ai/onnxruntime/OrtEnvironment.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, 2024 Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2025 Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
Expand All @@ -8,7 +8,11 @@
import ai.onnxruntime.OrtTrainingSession.OrtCheckpointState;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.logging.Logger;

Expand Down Expand Up @@ -442,6 +446,48 @@ public static EnumSet<OrtProvider> getAvailableProviders() {
return OnnxRuntime.providers.clone();
}

/**
* Registers an execution provider library with this OrtEnvironment.
*
* @param registrationName The name to register the library with (used to remove it later with
* {@link #unregisterExecutionProviderLibrary(String)}).
* @param libraryPath The path to the library binary on disk.
* @throws OrtException If the library could not be registered.
*/
public void registerExecutionProviderLibrary(String registrationName, String libraryPath)
throws OrtException {
registerExecutionProviderLibrary(
OnnxRuntime.ortApiHandle, nativeHandle, registrationName, libraryPath);
}

/**
* Unregisters an execution provider library from this OrtEnvironment.
*
* @param registrationName The name the library was registered under.
* @throws OrtException If the library could not be removed.
*/
public void unregisterExecutionProviderLibrary(String registrationName) throws OrtException {
unregisterExecutionProviderLibrary(OnnxRuntime.ortApiHandle, nativeHandle, registrationName);
}

/**
* Get the list of all execution provider and device combinations that are available.
*
* @see OrtSession.SessionOptions#addExecutionProvider(List, Map)
* @return The list of execution provider and device combinations.
* @throws OrtException If the devices could not be listed.
*/
public List<OrtEpDevice> getEpDevices() throws OrtException {
long[] deviceHandles = getEpDevices(OnnxRuntime.ortApiHandle, nativeHandle);

List<OrtEpDevice> devicesList = new ArrayList<>();
for (long deviceHandle : deviceHandles) {
devicesList.add(new OrtEpDevice(deviceHandle));
}

return Collections.unmodifiableList(devicesList);
}

/**
* Creates the native object.
*
Expand Down Expand Up @@ -476,6 +522,40 @@ private static native long createHandle(
*/
private static native long getDefaultAllocator(long apiHandle) throws OrtException;

/**
* Registers the specified execution provider with this OrtEnvironment.
*
* @param apiHandle The API handle.
* @param nativeHandle The OrtEnvironment handle.
* @param registrationName The name of the execution provider.
* @param libraryPath The path to the execution provider binary.
* @throws OrtException If the registration failed.
*/
private static native void registerExecutionProviderLibrary(
long apiHandle, long nativeHandle, String registrationName, String libraryPath)
throws OrtException;

/**
* Removes the specified execution provider from this OrtEnvironment.
*
* @param apiHandle The API handle.
* @param nativeHandle The OrtEnvironment handle.
* @param registrationName The name of the execution provider.
* @throws OrtException If the removal failed.
*/
private static native void unregisterExecutionProviderLibrary(
long apiHandle, long nativeHandle, String registrationName) throws OrtException;

/**
* Gets handles for the EP device tuples available in this OrtEnvironment.
*
* @param apiHandle The API handle to use.
* @param nativeHandle The OrtEnvironment handle.
* @return An array of OrtEpDevice handles.
* @throws OrtException If the call failed.
*/
private static native long[] getEpDevices(long apiHandle, long nativeHandle) throws OrtException;

/**
* Closes the OrtEnvironment, frees the handle.
*
Expand Down
117 changes: 117 additions & 0 deletions java/src/main/java/ai/onnxruntime/OrtEpDevice.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/*
* Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;

import java.util.Map;

/** A tuple of Execution Provider information and the hardware device. */
public final class OrtEpDevice {

private final long nativeHandle;

private final String epName;
private final String epVendor;
private final Map<String, String> epMetadata;
private final Map<String, String> epOptions;
private final OrtHardwareDevice device;

/**
* Construct an OrtEpDevice tuple from the native pointer.
*
* @param nativeHandle The native pointer.
*/
OrtEpDevice(long nativeHandle) {
this.nativeHandle = nativeHandle;
this.epName = getName(OnnxRuntime.ortApiHandle, nativeHandle);
this.epVendor = getVendor(OnnxRuntime.ortApiHandle, nativeHandle);
String[][] metadata = getMetadata(OnnxRuntime.ortApiHandle, nativeHandle);
this.epMetadata = OrtUtil.convertToMap(metadata);
String[][] options = getOptions(OnnxRuntime.ortApiHandle, nativeHandle);
this.epOptions = OrtUtil.convertToMap(options);
this.device = new OrtHardwareDevice(getDeviceHandle(OnnxRuntime.ortApiHandle, nativeHandle));
}

/**
* Return the native pointer.
*
* @return The native pointer.
*/
long getNativeHandle() {
return nativeHandle;
}

/**
* Gets the EP name.
*
* @return The EP name.
*/
public String getName() {
return epName;
}

/**
* Gets the vendor name.
*
* @return The vendor name.
*/
public String getVendor() {
return epVendor;
}

/**
* Gets an unmodifiable view on the EP metadata.
*
* @return The EP metadata.
*/
public Map<String, String> getMetadata() {
return epMetadata;
}

/**
* Gets an unmodifiable view on the EP options.
*
* @return The EP options.
*/
public Map<String, String> getOptions() {
return epOptions;
}

/**
* Gets the device information.
*
* @return The device information.
*/
public OrtHardwareDevice getDevice() {
return device;
}

@Override
public String toString() {
return "OrtEpDevice{"
+ "epName='"
+ epName
+ '\''
+ ", epVendor='"
+ epVendor
+ '\''
+ ", epMetadata="
+ epMetadata
+ ", epOptions="
+ epOptions
+ ", device="
+ device
+ '}';
}

private static native String getName(long apiHandle, long nativeHandle);

private static native String getVendor(long apiHandle, long nativeHandle);

private static native String[][] getMetadata(long apiHandle, long nativeHandle);

private static native String[][] getOptions(long apiHandle, long nativeHandle);

private static native long getDeviceHandle(long apiHandle, long nativeHandle);
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
/*
* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime.providers;
package ai.onnxruntime;

import java.util.EnumSet;

Expand Down
Loading
Loading