Skip to content

Commit

Permalink
[java] CUDA & TensorRT options fix (microsoft#20549)
Browse files Browse the repository at this point in the history
### Description
I misunderstood how UpdateCUDAProviderOptions and
UpdateTensorRTProviderOptions work in the C API, I had assumed that they
updated the options struct, however they re-initialize the struct to the
defaults then only apply the values in the update. I've rewritten the
Java bindings for those classes so that they aggregate all the updates
and apply them in one go. I also updated the C API documentation to note
that these classes have this behaviour. I've not checked if any of the
other providers with an options struct have this behaviour, we only
expose CUDA and TensorRT's options in Java.

There's a small unrelated update to add a private constructor to the
Fp16Conversions classes to remove a documentation warning (they
shouldn't be instantiated anyway as they are utility classes containing
static methods).

### Motivation and Context
Fixes microsoft#20544.
  • Loading branch information
Craigacp authored May 5, 2024
1 parent baaef59 commit a366920
Show file tree
Hide file tree
Showing 12 changed files with 158 additions and 64 deletions.
4 changes: 2 additions & 2 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2937,7 +2937,7 @@ struct OrtApi {
*
* Please refer to https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#cc
* to know the available keys and values. Key should be in null terminated string format of the member of ::OrtTensorRTProviderOptionsV2
* and value should be its related range.
* and value should be its related range. Recreates the options and only sets the supplied values.
*
* For example, key="trt_max_workspace_size" and value="2147483648"
*
Expand Down Expand Up @@ -3433,7 +3433,7 @@ struct OrtApi {
*
* Please refer to https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#configuration-options
* to know the available keys and values. Key should be in null terminated string format of the member of ::OrtCUDAProviderOptionsV2
* and value should be its related range.
* and value should be its related range. Recreates the options and only sets the supplied values.
*
* For example, key="device_id" and value="0"
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
/** * Conversions between fp16, bfloat16 and fp32. */
public final class Fp16Conversions {
private static final Logger logger = Logger.getLogger(Fp16Conversions.class.getName());


private Fp16Conversions() {}

/**
* Rounds a buffer of floats into a buffer containing fp16 values (stored as shorts in Java).
*
Expand Down
9 changes: 8 additions & 1 deletion java/src/main/java/ai/onnxruntime/OrtProviderOptions.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, 2023, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2022, 2024, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
Expand Down Expand Up @@ -53,6 +53,13 @@ protected static long getApiHandle() {
*/
public abstract OrtProvider getProvider();

/**
* Applies the Java side configuration to the native side object.
*
* @throws OrtException If the native call failed.
*/
protected abstract void applyToNative() throws OrtException;

/**
* Is the native object closed?
*
Expand Down
6 changes: 5 additions & 1 deletion java/src/main/java/ai/onnxruntime/OrtSession.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2024, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
Expand Down Expand Up @@ -1022,6 +1022,8 @@ public void addCUDA(int deviceNum) throws OrtException {
public void addCUDA(OrtCUDAProviderOptions cudaOpts) throws OrtException {
checkClosed();
if (OnnxRuntime.extractCUDA()) {
// Cast is to make the compiler pick the right overload.
((OrtProviderOptions) cudaOpts).applyToNative();
addCUDAV2(OnnxRuntime.ortApiHandle, nativeHandle, cudaOpts.nativeHandle);
} else {
throw new OrtException(
Expand Down Expand Up @@ -1125,6 +1127,8 @@ public void addTensorrt(int deviceNum) throws OrtException {
public void addTensorrt(OrtTensorRTProviderOptions tensorRTOpts) throws OrtException {
checkClosed();
if (OnnxRuntime.extractTensorRT()) {
// Cast is to make the compiler pick the right overload.
((OrtProviderOptions) tensorRTOpts).applyToNative();
addTensorrtV2(OnnxRuntime.ortApiHandle, nativeHandle, tensorRTOpts.nativeHandle);
} else {
throw new OrtException(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2022, 2024, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime.providers;
Expand Down Expand Up @@ -41,7 +41,6 @@ public OrtCUDAProviderOptions(int deviceId) throws OrtException {

String id = "" + deviceId;
this.options.put("device_id", id);
add(getApiHandle(), this.nativeHandle, "device_id", id);
}

@Override
Expand All @@ -59,17 +58,17 @@ public OrtProvider getProvider() {
private static native long create(long apiHandle) throws OrtException;

/**
* Adds an option to this options instance.
* Adds the options to this options instance.
*
* @param apiHandle The api pointer.
* @param nativeHandle The native options pointer.
* @param key The option key.
* @param value The option value.
* @param keys The option keys.
* @param values The option values.
* @throws OrtException If the addition failed.
*/
@Override
protected native void add(long apiHandle, long nativeHandle, String key, String value)
throws OrtException;
protected native void applyToNative(
long apiHandle, long nativeHandle, String[] keys, String[] values) throws OrtException;

/**
* Closes this options instance.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2022, 2024, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime.providers;
Expand Down Expand Up @@ -41,7 +41,6 @@ public OrtTensorRTProviderOptions(int deviceId) throws OrtException {

String id = "" + deviceId;
this.options.put("device_id", id);
add(getApiHandle(), this.nativeHandle, "device_id", id);
}

@Override
Expand All @@ -59,17 +58,17 @@ public OrtProvider getProvider() {
private static native long create(long apiHandle) throws OrtException;

/**
* Adds an option to this options instance.
* Adds the options to this options instance.
*
* @param apiHandle The api pointer.
* @param nativeHandle The native options pointer.
* @param key The option key.
* @param value The option value.
* @param keys The option keys.
* @param values The option values.
* @throws OrtException If the addition failed.
*/
@Override
protected native void add(long apiHandle, long nativeHandle, String key, String value)
throws OrtException;
protected native void applyToNative(
long apiHandle, long nativeHandle, String[] keys, String[] values) throws OrtException;

/**
* Closes this options instance.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, 2023, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2022, 2024, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime.providers;
Expand Down Expand Up @@ -36,7 +36,6 @@ public void add(String key, String value) throws OrtException {
Objects.requireNonNull(key, "Key must not be null");
Objects.requireNonNull(value, "Value must not be null");
options.put(key, value);
add(getApiHandle(), nativeHandle, key, value);
}

/**
Expand All @@ -49,7 +48,7 @@ public void add(String key, String value) throws OrtException {
public void parseOptionsString(String serializedForm) throws OrtException {
String[] options = serializedForm.split(";");
for (String o : options) {
if (!o.isEmpty() && o.contains("=")) {
if (o.contains("=")) {
String[] curOption = o.split("=");
if ((curOption.length == 2) && !curOption[0].isEmpty() && !curOption[1].isEmpty()) {
add(curOption[0], curOption[1]);
Expand All @@ -76,15 +75,31 @@ public String getOptionsString() {
.collect(Collectors.joining(";", "", ";"));
}

@Override
protected void applyToNative() throws OrtException {
if (!options.isEmpty()) {
String[] keys = new String[options.size()];
String[] values = new String[options.size()];
int i = 0;
for (Map.Entry<String, String> e : options.entrySet()) {
keys[i] = e.getKey();
values[i] = e.getValue();
i++;
}

applyToNative(getApiHandle(), this.nativeHandle, keys, values);
}
}

/**
* Adds an option to this options instance.
* Add all the options to this options instance.
*
* @param apiHandle The api pointer.
* @param nativeHandle The native options pointer.
* @param key The option key.
* @param value The option value.
* @param key The option keys.
* @param value The option values.
* @throws OrtException If the addition failed.
*/
protected abstract void add(long apiHandle, long nativeHandle, String key, String value)
throws OrtException;
protected abstract void applyToNative(
long apiHandle, long nativeHandle, String[] key, String[] value) throws OrtException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ public final class Fp16Conversions {
fp32ToFp16 = tmp32;
}

private Fp16Conversions() {}

/**
* Rounds a buffer of floats into a buffer containing fp16 values (stored as shorts in Java).
*
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022 Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2022, 2024 Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
#include <jni.h>
Expand All @@ -24,19 +24,46 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_providers_OrtCUDAProviderOptions_cre

/*
* Class: ai_onnxruntime_providers_OrtCUDAProviderOptions
* Method: add
* Signature: (JJLjava/lang/String;Ljava/lang/String;)V
* Method: applyToNative
* Signature: (JJ[Ljava/lang/String;[Ljava/lang/String;)V
*/
JNIEXPORT void JNICALL Java_ai_onnxruntime_providers_OrtCUDAProviderOptions_add
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jstring key, jstring value) {
JNIEXPORT void JNICALL Java_ai_onnxruntime_providers_OrtCUDAProviderOptions_applyToNative
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jobjectArray jKeyArr, jobjectArray jValueArr) {
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*)apiHandle;
OrtCUDAProviderOptionsV2* opts = (OrtCUDAProviderOptionsV2*) optionsHandle;
const char* keyStr = (*jniEnv)->GetStringUTFChars(jniEnv, key, NULL);
const char* valueStr = (*jniEnv)->GetStringUTFChars(jniEnv, value, NULL);
checkOrtStatus(jniEnv,api,api->UpdateCUDAProviderOptions(opts, &keyStr, &valueStr, 1));
(*jniEnv)->ReleaseStringUTFChars(jniEnv,key,keyStr);
(*jniEnv)->ReleaseStringUTFChars(jniEnv,value,valueStr);

jsize keyLength = (*jniEnv)->GetArrayLength(jniEnv, jKeyArr);
const char** keys = (const char**) allocarray(keyLength, sizeof(const char*));
const char** values = (const char**) allocarray(keyLength, sizeof(const char*));
if ((keys == NULL) || (values == NULL)) {
if (keys != NULL) {
free((void*)keys);
}
if (values != NULL) {
free((void*)values);
}
throwOrtException(jniEnv, 1, "Not enough memory");
} else {
// Copy out strings into UTF-8.
for (jsize i = 0; i < keyLength; i++) {
jobject key = (*jniEnv)->GetObjectArrayElement(jniEnv, jKeyArr, i);
keys[i] = (*jniEnv)->GetStringUTFChars(jniEnv, key, NULL);
jobject value = (*jniEnv)->GetObjectArrayElement(jniEnv, jValueArr, i);
values[i] = (*jniEnv)->GetStringUTFChars(jniEnv, value, NULL);
}
// Write to the provider options.
checkOrtStatus(jniEnv,api,api->UpdateCUDAProviderOptions(opts, keys, values, keyLength));
// Release allocated strings.
for (jsize i = 0; i < keyLength; i++) {
jobject key = (*jniEnv)->GetObjectArrayElement(jniEnv, jKeyArr, i);
(*jniEnv)->ReleaseStringUTFChars(jniEnv,key,keys[i]);
jobject value = (*jniEnv)->GetObjectArrayElement(jniEnv, jKeyArr, i);
(*jniEnv)->ReleaseStringUTFChars(jniEnv,value,values[i]);
}
free((void*)keys);
free((void*)values);
}
}

/*
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022 Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2022, 2024 Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
#include <jni.h>
Expand All @@ -23,19 +23,46 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_providers_OrtTensorRTProviderOptions

/*
* Class: ai_onnxruntime_providers_OrtTensorRTProviderOptions
* Method: add
* Signature: (JJLjava/lang/String;Ljava/lang/String;)V
* Method: applyToNative
* Signature: (JJ[Ljava/lang/String;[Ljava/lang/String;)V
*/
JNIEXPORT void JNICALL Java_ai_onnxruntime_providers_OrtTensorRTProviderOptions_add
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jstring key, jstring value) {
JNIEXPORT void JNICALL Java_ai_onnxruntime_providers_OrtTensorRTProviderOptions_applyToNative
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jobjectArray jKeyArr, jobjectArray jValueArr) {
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*)apiHandle;
OrtTensorRTProviderOptionsV2* opts = (OrtTensorRTProviderOptionsV2*) optionsHandle;
const char* keyStr = (*jniEnv)->GetStringUTFChars(jniEnv, key, NULL);
const char* valueStr = (*jniEnv)->GetStringUTFChars(jniEnv, value, NULL);
checkOrtStatus(jniEnv,api,api->UpdateTensorRTProviderOptions(opts, &keyStr, &valueStr, 1));
(*jniEnv)->ReleaseStringUTFChars(jniEnv,key,keyStr);
(*jniEnv)->ReleaseStringUTFChars(jniEnv,value,valueStr);

jsize keyLength = (*jniEnv)->GetArrayLength(jniEnv, jKeyArr);
const char** keys = (const char**) allocarray(keyLength, sizeof(const char*));
const char** values = (const char**) allocarray(keyLength, sizeof(const char*));
if ((keys == NULL) || (values == NULL)) {
if (keys != NULL) {
free((void*)keys);
}
if (values != NULL) {
free((void*)values);
}
throwOrtException(jniEnv, 1, "Not enough memory");
} else {
// Copy out strings into UTF-8.
for (jsize i = 0; i < keyLength; i++) {
jobject key = (*jniEnv)->GetObjectArrayElement(jniEnv, jKeyArr, i);
keys[i] = (*jniEnv)->GetStringUTFChars(jniEnv, key, NULL);
jobject value = (*jniEnv)->GetObjectArrayElement(jniEnv, jValueArr, i);
values[i] = (*jniEnv)->GetStringUTFChars(jniEnv, value, NULL);
}
// Write to the provider options.
checkOrtStatus(jniEnv,api,api->UpdateTensorRTProviderOptions(opts, keys, values, keyLength));
// Release allocated strings.
for (jsize i = 0; i < keyLength; i++) {
jobject key = (*jniEnv)->GetObjectArrayElement(jniEnv, jKeyArr, i);
(*jniEnv)->ReleaseStringUTFChars(jniEnv,key,keys[i]);
jobject value = (*jniEnv)->GetObjectArrayElement(jniEnv, jKeyArr, i);
(*jniEnv)->ReleaseStringUTFChars(jniEnv,value,values[i]);
}
free((void*)keys);
free((void*)values);
}
}

/*
Expand Down
5 changes: 4 additions & 1 deletion java/src/test/java/ai/onnxruntime/InferenceTest.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2024, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
Expand Down Expand Up @@ -678,6 +678,9 @@ private void runProvider(OrtProvider provider) throws OrtException {
if (provider == OrtProvider.CORE_ML) {
// CoreML gives slightly different answers on a 2020 13" M1 MBP
assertArrayEquals(expectedOutput, resultArray, 1e-2f);
} else if (provider == OrtProvider.CUDA) {
// CUDA gives slightly different answers on a H100 with CUDA 12.2
assertArrayEquals(expectedOutput, resultArray, 1e-3f);
} else {
assertArrayEquals(expectedOutput, resultArray, 1e-5f);
}
Expand Down
Loading

0 comments on commit a366920

Please sign in to comment.