Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cmake/onnxruntime_java.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ if (WIN32)
if(NOT onnxruntime_ENABLE_STATIC_ANALYSIS)
add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $<TARGET_FILE:onnxruntime> ${JAVA_PACKAGE_LIB_DIR}/$<TARGET_FILE_NAME:onnxruntime>)
add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $<TARGET_FILE:onnxruntime4j_jni> ${JAVA_PACKAGE_JNI_DIR}/$<TARGET_FILE_NAME:onnxruntime4j_jni>)
if (onnxruntime_USE_CUDA OR onnxruntime_USE_DNNL OR onnxruntime_USE_OPENVINO OR onnxruntime_USE_TENSORRT OR (onnxruntime_USE_QNN AND NOT onnxruntime_BUILD_QNN_EP_STATIC_LIB))
if (TARGET onnxruntime_providers_shared)
add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $<TARGET_FILE:onnxruntime_providers_shared> ${JAVA_PACKAGE_LIB_DIR}/$<TARGET_FILE_NAME:onnxruntime_providers_shared>)
endif()
if (onnxruntime_USE_CUDA)
Expand Down Expand Up @@ -205,7 +205,7 @@ if (WIN32)
else()
add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $<TARGET_FILE:onnxruntime> ${JAVA_PACKAGE_LIB_DIR}/$<TARGET_LINKER_FILE_NAME:onnxruntime>)
add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $<TARGET_FILE:onnxruntime4j_jni> ${JAVA_PACKAGE_JNI_DIR}/$<TARGET_LINKER_FILE_NAME:onnxruntime4j_jni>)
if (onnxruntime_USE_CUDA OR onnxruntime_USE_DNNL OR onnxruntime_USE_OPENVINO OR onnxruntime_USE_TENSORRT OR (onnxruntime_USE_QNN AND NOT onnxruntime_BUILD_QNN_EP_STATIC_LIB))
if (TARGET onnxruntime_providers_shared)
add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $<TARGET_FILE:onnxruntime_providers_shared> ${JAVA_PACKAGE_LIB_DIR}/$<TARGET_LINKER_FILE_NAME:onnxruntime_providers_shared>)
endif()
if (onnxruntime_USE_CUDA)
Expand Down
4 changes: 4 additions & 0 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -1657,6 +1657,10 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
add_custom_command(TARGET onnxruntime_providers_qnn POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy ${QNN_LIB_FILES} ${JAVA_NATIVE_TEST_DIR})
endif()
if (WIN32)
set(EXAMPLE_PLUGIN_EP_DST_FILE_NAME $<IF:$<BOOL:${WIN32}>,$<TARGET_FILE_NAME:example_plugin_ep>,$<TARGET_LINKER_FILE_NAME:example_plugin_ep>>)
add_custom_command(TARGET custom_op_library POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $<TARGET_FILE:example_plugin_ep> ${JAVA_NATIVE_TEST_DIR}/${EXAMPLE_PLUGIN_EP_DST_FILE_NAME})
endif()

# delegate to gradle's test runner

Expand Down
18 changes: 16 additions & 2 deletions java/src/main/java/ai/onnxruntime/OnnxRuntime.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ final class OnnxRuntime {
private static final int ORT_API_VERSION_13 = 13;
// Post 1.13 builds of the ORT API
private static final int ORT_API_VERSION_14 = 14;
// Post 1.22 builds of the ORT API
private static final int ORT_API_VERSION_23 = 23;

// The initial release of the ORT training API.
private static final int ORT_TRAINING_API_VERSION_1 = 1;
Expand Down Expand Up @@ -103,6 +105,9 @@ final class OnnxRuntime {
/** The Training API handle. */
static long ortTrainingApiHandle;

/** The Compile API handle. */
static long ortCompileApiHandle;

/** Is training enabled in the native library */
static boolean trainingEnabled;

Expand Down Expand Up @@ -174,12 +179,13 @@ static synchronized void init() throws IOException {
}
load(ONNXRUNTIME_JNI_LIBRARY_NAME);

ortApiHandle = initialiseAPIBase(ORT_API_VERSION_14);
ortApiHandle = initialiseAPIBase(ORT_API_VERSION_23);
if (ortApiHandle == 0L) {
throw new IllegalStateException(
"There is a mismatch between the ORT class files and the ORT native library, and the native library could not be loaded");
}
ortTrainingApiHandle = initialiseTrainingAPIBase(ortApiHandle, ORT_API_VERSION_14);
ortTrainingApiHandle = initialiseTrainingAPIBase(ortApiHandle, ORT_API_VERSION_23);
ortCompileApiHandle = initialiseCompileAPIBase(ortApiHandle);
trainingEnabled = ortTrainingApiHandle != 0L;
providers = initialiseProviders(ortApiHandle);
version = initialiseVersion();
Expand Down Expand Up @@ -497,6 +503,14 @@ private static EnumSet<OrtProvider> initialiseProviders(long ortApiHandle) {
*/
private static native long initialiseTrainingAPIBase(long apiHandle, int apiVersionNumber);

/**
* Get a reference to the compile API struct.
*
* @param apiHandle The ORT API struct pointer.
* @return A pointer to the compile API struct.
*/
private static native long initialiseCompileAPIBase(long apiHandle);

/**
* Gets the array of available providers.
*
Expand Down
82 changes: 81 additions & 1 deletion java/src/main/java/ai/onnxruntime/OrtEnvironment.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, 2024 Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2025 Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
Expand All @@ -8,7 +8,11 @@
import ai.onnxruntime.OrtTrainingSession.OrtCheckpointState;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.logging.Logger;

Expand Down Expand Up @@ -442,6 +446,48 @@ public static EnumSet<OrtProvider> getAvailableProviders() {
return OnnxRuntime.providers.clone();
}

/**
* Registers an execution provider library with this OrtEnvironment.
*
* @param registrationName The name to register the library with (used to remove it later with
* {@link #unregisterExecutionProviderLibrary(String)}).
* @param libraryPath The path to the library binary on disk.
* @throws OrtException If the library could not be registered.
*/
public void registerExecutionProviderLibrary(String registrationName, String libraryPath)
throws OrtException {
registerExecutionProviderLibrary(
OnnxRuntime.ortApiHandle, nativeHandle, registrationName, libraryPath);
}

/**
* Unregisters an execution provider library from this OrtEnvironment.
*
* @param registrationName The name the library was registered under.
* @throws OrtException If the library could not be removed.
*/
public void unregisterExecutionProviderLibrary(String registrationName) throws OrtException {
unregisterExecutionProviderLibrary(OnnxRuntime.ortApiHandle, nativeHandle, registrationName);
}

/**
* Get the list of all execution provider and device combinations that are available.
*
* @see OrtSession.SessionOptions#addExecutionProvider(List, Map)
* @return The list of execution provider and device combinations.
* @throws OrtException If the devices could not be listed.
*/
public List<OrtEpDevice> getEpDevices() throws OrtException {
long[] deviceHandles = getEpDevices(OnnxRuntime.ortApiHandle, nativeHandle);

List<OrtEpDevice> devicesList = new ArrayList<>();
for (long deviceHandle : deviceHandles) {
devicesList.add(new OrtEpDevice(deviceHandle));
}

return Collections.unmodifiableList(devicesList);
}

/**
* Creates the native object.
*
Expand Down Expand Up @@ -476,6 +522,40 @@ private static native long createHandle(
*/
private static native long getDefaultAllocator(long apiHandle) throws OrtException;

/**
* Registers the specified execution provider with this OrtEnvironment.
*
* @param apiHandle The API handle.
* @param nativeHandle The OrtEnvironment handle.
* @param registrationName The name of the execution provider.
* @param libraryPath The path to the execution provider binary.
* @throws OrtException If the registration failed.
*/
private static native void registerExecutionProviderLibrary(
long apiHandle, long nativeHandle, String registrationName, String libraryPath)
throws OrtException;

/**
* Removes the specified execution provider from this OrtEnvironment.
*
* @param apiHandle The API handle.
* @param nativeHandle The OrtEnvironment handle.
* @param registrationName The name of the execution provider.
* @throws OrtException If the removal failed.
*/
private static native void unregisterExecutionProviderLibrary(
long apiHandle, long nativeHandle, String registrationName) throws OrtException;

/**
* Gets handles for the EP device tuples available in this OrtEnvironment.
*
* @param apiHandle The API handle to use.
* @param nativeHandle The OrtEnvironment handle.
* @return An array of OrtEpDevice handles.
* @throws OrtException If the call failed.
*/
private static native long[] getEpDevices(long apiHandle, long nativeHandle) throws OrtException;

/**
* Closes the OrtEnvironment, frees the handle.
*
Expand Down
117 changes: 117 additions & 0 deletions java/src/main/java/ai/onnxruntime/OrtEpDevice.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/*
* Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;

import java.util.Map;

/** A tuple of Execution Provider information and the hardware device. */
public final class OrtEpDevice {

private final long nativeHandle;

private final String epName;
private final String epVendor;
private final Map<String, String> epMetadata;
private final Map<String, String> epOptions;
private final OrtHardwareDevice device;

/**
* Construct an OrtEpDevice tuple from the native pointer.
*
* @param nativeHandle The native pointer.
*/
OrtEpDevice(long nativeHandle) {
this.nativeHandle = nativeHandle;
this.epName = getName(OnnxRuntime.ortApiHandle, nativeHandle);
this.epVendor = getVendor(OnnxRuntime.ortApiHandle, nativeHandle);
String[][] metadata = getMetadata(OnnxRuntime.ortApiHandle, nativeHandle);
this.epMetadata = OrtUtil.convertToMap(metadata);
String[][] options = getOptions(OnnxRuntime.ortApiHandle, nativeHandle);
this.epOptions = OrtUtil.convertToMap(options);
this.device = new OrtHardwareDevice(getDeviceHandle(OnnxRuntime.ortApiHandle, nativeHandle));
}

/**
* Return the native pointer.
*
* @return The native pointer.
*/
long getNativeHandle() {
return nativeHandle;
}

/**
* Gets the EP name.
*
* @return The EP name.
*/
public String getName() {
return epName;
}

/**
* Gets the vendor name.
*
* @return The vendor name.
*/
public String getVendor() {
return epVendor;
}

/**
* Gets an unmodifiable view on the EP metadata.
*
* @return The EP metadata.
*/
public Map<String, String> getMetadata() {
return epMetadata;
}

/**
* Gets an unmodifiable view on the EP options.
*
* @return The EP options.
*/
public Map<String, String> getOptions() {
return epOptions;
}

/**
* Gets the device information.
*
* @return The device information.
*/
public OrtHardwareDevice getDevice() {
return device;
}

@Override
public String toString() {
return "OrtEpDevice{"
+ "epName='"
+ epName
+ '\''
+ ", epVendor='"
+ epVendor
+ '\''
+ ", epMetadata="
+ epMetadata
+ ", epOptions="
+ epOptions
+ ", device="
+ device
+ '}';
}

private static native String getName(long apiHandle, long nativeHandle);

private static native String getVendor(long apiHandle, long nativeHandle);

private static native String[][] getMetadata(long apiHandle, long nativeHandle);

private static native String[][] getOptions(long apiHandle, long nativeHandle);

private static native long getDeviceHandle(long apiHandle, long nativeHandle);
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
/*
* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime.providers;
package ai.onnxruntime;

import java.util.EnumSet;

Expand Down
Loading
Loading