From 262b6bd3b7531503f40f2cb6059d22d7d9d84f27 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Tue, 2 Apr 2024 00:58:50 -0400 Subject: [PATCH] [java][DML EP] Modifying dml_provider_factory.h so it can compile as a C header file (#20157) ### Description The dml_provider_factory header file can't be used in C programs as it defines C++ inline operators. This PR rearranges that header file so that it looks like valid C when used from C, and also makes a couple of small modifications to the Java code so it correctly binds to the DML EP at build time. I'm having some difficulty testing it as I think it's pulling in the old version of DirectML on my computer and I can't figure out what the library loading path is in Java to make it look at the recent version I downloaded. So the test I added fails with: ``` InferenceTest > testDirectML() FAILED ai.onnxruntime.OrtException: Error code - ORT_RUNTIME_EXCEPTION - message: Exception during initialization: \onnxruntime\core\providers\dml\DmlExecutionProvider\src\AbiCustomRegistry.cpp(518)\onnxruntime.dll!00007FFF74819333: (caller: 00007FFF74793509) Exception(3) tid(4f58) 80070057 The parameter is incorrect. at app//ai.onnxruntime.OrtSession.createSession(Native Method) at app//ai.onnxruntime.OrtSession.(OrtSession.java:74) at app//ai.onnxruntime.OrtEnvironment.createSession(OrtEnvironment.java:236) at app//ai.onnxruntime.OrtEnvironment.createSession(OrtEnvironment.java:221) at app//ai.onnxruntime.InferenceTest.openSessionSqueezeNet(InferenceTest.java:1961) at app//ai.onnxruntime.InferenceTest.runProvider(InferenceTest.java:665) at app//ai.onnxruntime.InferenceTest.testDirectML(InferenceTest.java:657) ``` But it does correctly compile, and this error seems very similar to other issues with the DML provider when it doesn't like a model due to the loaded library being old. The test is using the squeezenet file that's been in the repo since 2019. If someone can help me figure out how to get the right version of DML in the library path I can test it more on my end. I tried adding the folder with the new version into the system path, but I'm not very familiar with Windows' library loading behaviour. ### Motivation and Context Fixes #19656 to allow use of the DirectML EP from ORT Java. cc @martinb35 --- .../core/providers/dml/dml_provider_factory.h | 30 ++++++++++++++----- java/build.gradle | 2 +- ...ai_onnxruntime_OrtSession_SessionOptions.c | 2 +- .../java/ai/onnxruntime/InferenceTest.java | 8 +++++ 4 files changed, 33 insertions(+), 9 deletions(-) diff --git a/include/onnxruntime/core/providers/dml/dml_provider_factory.h b/include/onnxruntime/core/providers/dml/dml_provider_factory.h index 7d7f05193f48..33b98edf3bf4 100644 --- a/include/onnxruntime/core/providers/dml/dml_provider_factory.h +++ b/include/onnxruntime/core/providers/dml/dml_provider_factory.h @@ -27,14 +27,8 @@ typedef struct IDMLDevice IDMLDevice; #include "onnxruntime_c_api.h" #ifdef __cplusplus -extern "C" { -#endif -enum OrtDmlPerformancePreference { - Default = 0, - HighPerformance = 1, - MinimumPower = 2 -}; +extern "C" { enum OrtDmlDeviceFilter : uint32_t { #ifdef ENABLE_NPU_ADAPTER_ENUMERATION @@ -54,11 +48,33 @@ inline OrtDmlDeviceFilter& operator|=(OrtDmlDeviceFilter& a, OrtDmlDeviceFilter inline OrtDmlDeviceFilter& operator&=(OrtDmlDeviceFilter& a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter&)((int&)a &= (int)b); } inline OrtDmlDeviceFilter& operator^=(OrtDmlDeviceFilter& a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter&)((int&)a ^= (int)b); } +#else + +typedef enum OrtDmlDeviceFilter { +#ifdef ENABLE_NPU_ADAPTER_ENUMERATION + Any = 0xffffffff, + Gpu = 1 << 0, + Npu = 1 << 1, +#else + Gpu = 1 << 0, +#endif +} OrtDmlDeviceFilter; + +#endif + +typedef enum OrtDmlPerformancePreference { + Default = 0, + HighPerformance = 1, + MinimumPower = 2 +} OrtDmlPerformancePreference; + struct OrtDmlDeviceOptions { OrtDmlPerformancePreference Preference; OrtDmlDeviceFilter Filter; }; +typedef struct OrtDmlDeviceOptions OrtDmlDeviceOptions; + /** * [[deprecated]] * This export is deprecated. diff --git a/java/build.gradle b/java/build.gradle index 5a0c4a9e3937..fd66ec220b78 100644 --- a/java/build.gradle +++ b/java/build.gradle @@ -185,7 +185,7 @@ test { if (cmakeBuildDir != null) { workingDir cmakeBuildDir } - systemProperties System.getProperties().subMap(['USE_CUDA', 'USE_ROCM', 'USE_TENSORRT', 'USE_DNNL', 'USE_OPENVINO', 'USE_COREML', 'JAVA_FULL_TEST', 'ENABLE_TRAINING_APIS']) + systemProperties System.getProperties().subMap(['USE_CUDA', 'USE_ROCM', 'USE_TENSORRT', 'USE_DNNL', 'USE_OPENVINO', 'USE_COREML', 'USE_DML', 'JAVA_FULL_TEST', 'ENABLE_TRAINING_APIS']) testLogging { events "passed", "skipped", "failed" showStandardStreams = true diff --git a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c index 4a5e2b7ef3b1..337f4c1921c6 100644 --- a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c +++ b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c @@ -630,7 +630,7 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addMIG JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addDirectML (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint deviceID) { (void)jobj; - #ifdef USE_DIRECTML + #ifdef USE_DML checkOrtStatus(jniEnv,(const OrtApi*)apiHandle,OrtSessionOptionsAppendExecutionProvider_DML((OrtSessionOptions*) handle, deviceID)); #else (void)apiHandle;(void)handle;(void)deviceID; // Parameters used when DirectML is defined. diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index 9925197e4507..ac65cbab146b 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -651,6 +651,12 @@ public void testCoreML() throws OrtException { runProvider(OrtProvider.CORE_ML); } + @Test + @EnabledIfSystemProperty(named = "USE_DML", matches = "1") + public void testDirectML() throws OrtException { + runProvider(OrtProvider.DIRECT_ML); + } + private void runProvider(OrtProvider provider) throws OrtException { EnumSet providers = OrtEnvironment.getAvailableProviders(); assertTrue(providers.size() > 1); @@ -1926,6 +1932,8 @@ private static SqueezeNetTuple openSessionSqueezeNet(EnumSet provid options.addNnapi(); break; case DIRECT_ML: + options.setMemoryPatternOptimization(false); + options.setExecutionMode(ExecutionMode.SEQUENTIAL); options.addDirectML(0); break; case ACL: