Skip to content

Commit

Permalink
Merge pull request #964 from cescoffier/fix-onnx-native
Browse files Browse the repository at this point in the history
Fix ONNX Runtime Execution in Native Executable
  • Loading branch information
cescoffier authored Oct 2, 2024
2 parents 0663a72 + ca5e2e6 commit a4d7193
Show file tree
Hide file tree
Showing 39 changed files with 190 additions and 161 deletions.
35 changes: 10 additions & 25 deletions .github/workflows/build-pull-request.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,29 +20,21 @@ defaults:
shell: bash

jobs:
# Build the project, no native tests.
build-and-test-jvm:
name: Build on ${{ matrix.os }} - ${{ matrix.java }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
java: [17]
runs-on: ${{ matrix.os }}
name: Main Build
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}

steps:
- name: Prepare git
run: git config --global core.autocrlf false
if: startsWith(matrix.os, 'windows')

- uses: actions/checkout@v3
- uses: actions/checkout@v4

- name: Set up JDK ${{ matrix.java }}
- name: Set up JDK 17
uses: actions/setup-java@v4
with:
distribution: temurin
java-version: ${{ matrix.java }}
java-version: 17
cache: 'maven'

- name: Build with Maven
Expand All @@ -65,33 +57,27 @@ jobs:
run: |
cd integration-tests
# skip RAG module as it doesn't have any native-compatible tests now
# skip 'embed' modules (with in-process embeddings) and others that don't work in native
# FIXME: reenable embedding modules, see https://github.com/quarkiverse/quarkus-langchain4j/issues/394
MATRIX='{"testModule":'$( \
find . -mindepth 2 -maxdepth 2 -type f -name 'pom.xml' -exec dirname {} \; \
| sed 's|^\./||' \
| sort -u \
| grep -v rag \
| grep -v embed \
| grep -v jlama \
| jq -R -s -c 'split("\n")[:-1]' \
)'}'
echo "matrix=$MATRIX" >> $GITHUB_OUTPUT
# Test the project with different JDKs.
test-jvm-alt:
name: Test on ${{ matrix.os }} - ${{ matrix.java }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
java: [21]
java: [21, 22, 23]
name: Test on ${{ matrix.os }} - ${{ matrix.java }}
runs-on: ${{ matrix.os }}
steps:
- name: Prepare git
run: git config --global core.autocrlf false
if: startsWith(matrix.os, 'windows')

- uses: actions/checkout@v3
- uses: actions/checkout@v4

- name: Set up JDK ${{ matrix.java }}
uses: actions/setup-java@v4
Expand All @@ -116,7 +102,6 @@ jobs:
runs-on: ubuntu-latest

steps:

- uses: actions/checkout@v4

- name: Set up JDK 17
Expand Down
6 changes: 1 addition & 5 deletions .github/workflows/build-push.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,9 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest]
java: [17, 21]
java: [17, 21, 22, 23]
runs-on: ${{ matrix.os }}
steps:
- name: Prepare git
run: git config --global core.autocrlf false
if: startsWith(matrix.os, 'windows')

- uses: actions/checkout@v3
- name: Set up JDK ${{ matrix.java }}
uses: actions/setup-java@v3
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import io.quarkus.deployment.annotations.Record;
import io.quarkus.deployment.builditem.nativeimage.NativeImageResourceBuildItem;
import io.quarkus.deployment.builditem.nativeimage.ReflectiveClassBuildItem;
import io.quarkus.deployment.builditem.nativeimage.RuntimeInitializedClassBuildItem;

/**
* Generate a local embedding build item for each local embedding model available in the classpath.
Expand Down Expand Up @@ -82,6 +83,16 @@ public void generateLocalEmbeddingBuildItems(BuildProducer<InProcessEmbeddingBui
}
}

/**
 * Produces a single {@code RequireOnnxRuntimeBuildItem} when at least one of the given
 * in-process embedding build items reports that it requires the ONNX runtime.
 * Nothing is produced when the list is empty or no item asks for it.
 *
 * @param embedding the in-process embedding build items to inspect
 * @param producer the producer receiving the marker build item
 */
@BuildStep
void requireOnnxRuntime(List<InProcessEmbeddingBuildItem> embedding, BuildProducer<RequireOnnxRuntimeBuildItem> producer) {
    // anyMatch stops at the first item that needs ONNX, so the marker is produced at most once.
    boolean onnxNeeded = embedding.stream().anyMatch(InProcessEmbeddingBuildItem::requireOnnxRuntime);
    if (onnxNeeded) {
        producer.produce(new RequireOnnxRuntimeBuildItem());
    }
}

// Expose a bean for each in process embedding model
@BuildStep
@Record(ExecutionTime.RUNTIME_INIT)
Expand All @@ -108,25 +119,23 @@ void exposeInProcessEmbeddingBeans(InProcessEmbeddingRecorder recorder,
}
}

/**
 * Adds a {@code ModelName} qualifier whose {@code value} is {@code configName} to the given
 * bean configurator, unless {@code NamedConfigUtil.isDefault} reports the name as the default.
 *
 * @param builder the synthetic bean configurator to decorate
 * @param configName the configuration name of the model
 */
private void addQualifierIfNecessary(SyntheticBeanBuildItem.ExtendedBeanConfigurator builder, String configName) {
    if (NamedConfigUtil.isDefault(configName)) {
        // Default configuration: the bean stays unqualified.
        return;
    }
    AnnotationInstance modelNameQualifier = AnnotationInstance.builder(ModelName.class)
            .add("value", configName)
            .build();
    builder.addQualifier(modelNameQualifier);
}

@BuildStep
void includeInProcessEmbeddingModelsInNativeExecutable(
List<InProcessEmbeddingBuildItem> inProcessEmbeddingBuildItems,
void configureNativeExecutableForInProcessEmbedding(List<InProcessEmbeddingBuildItem> embeddings,
BuildProducer<RuntimeInitializedClassBuildItem> classes,
BuildProducer<NativeImageResourceBuildItem> resources,
BuildProducer<ReflectiveClassBuildItem> reflection) {
for (InProcessEmbeddingBuildItem inProcessEmbeddingBuildItem : inProcessEmbeddingBuildItems) {
for (InProcessEmbeddingBuildItem inProcessEmbeddingBuildItem : embeddings) {
classes.produce(new RuntimeInitializedClassBuildItem(inProcessEmbeddingBuildItem.className()));
resources.produce(new NativeImageResourceBuildItem(inProcessEmbeddingBuildItem.onnxModelPath()));
resources.produce(new NativeImageResourceBuildItem(inProcessEmbeddingBuildItem.vocabularyPath()));
reflection.produce(ReflectiveClassBuildItem.builder(inProcessEmbeddingBuildItem.className())
.constructors(true)
.fields(true)
.methods(true)
.build());
}
}

private void addQualifierIfNecessary(SyntheticBeanBuildItem.ExtendedBeanConfigurator builder, String configName) {
if (!NamedConfigUtil.isDefault(configName)) {
builder.addQualifier(AnnotationInstance.builder(ModelName.class).add("value", configName).build());
.constructors().fields().methods().build());
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package io.quarkiverse.langchain4j.deployment;

import java.util.List;
import java.util.stream.Stream;

import io.quarkiverse.langchain4j.deployment.items.InProcessEmbeddingBuildItem;
import io.quarkus.bootstrap.classloading.QuarkusClassLoader;
import io.quarkus.deployment.annotations.BuildProducer;
import io.quarkus.deployment.annotations.BuildStep;
import io.quarkus.deployment.annotations.Consume;
import io.quarkus.deployment.builditem.nativeimage.JniRuntimeAccessBuildItem;
import io.quarkus.deployment.builditem.nativeimage.NativeImageResourcePatternsBuildItem;
import io.quarkus.deployment.builditem.nativeimage.ReflectiveClassBuildItem;
import io.quarkus.deployment.builditem.nativeimage.RuntimeInitializedClassBuildItem;
import io.quarkus.deployment.builditem.nativeimage.RuntimeInitializedPackageBuildItem;

/**
 * Build-time processor that configures the native image build for the OnnxRuntime:
 * reflection and JNI registrations, bundled native-library resources, and classes that
 * must be initialized at runtime.
 * Both build steps consume {@code RequireOnnxRuntimeBuildItem}, so they are intended to
 * apply only when that build item is present.
 */
public class OnnxRuntimeProcessor {

    /**
     * Registers the reflection, JNI and resource configuration used by the ONNX runtime.
     *
     * @param nativePatternProducer receives the resource glob patterns to include in the image
     * @param reflectionProducer receives the classes to register for reflection
     * @param jniProducer receives the classes to register for JNI runtime access
     */
    @BuildStep
    @Consume(RequireOnnxRuntimeBuildItem.class)
    void onxxRuntimeNative(
            BuildProducer<NativeImageResourcePatternsBuildItem> nativePatternProducer,
            BuildProducer<ReflectiveClassBuildItem> reflectionProducer,
            BuildProducer<JniRuntimeAccessBuildItem> jniProducer) {
        // ONNX runtime classes that get instantiated from native code; they are registered
        // for reflection as well as for JNI access.
        List<String> nativeInstantiated = List.of(
                "ai.onnxruntime.TensorInfo",
                "ai.onnxruntime.SequenceInfo",
                "ai.onnxruntime.MapInfo",
                "ai.onnxruntime.OrtException",
                "ai.onnxruntime.OnnxSparseTensor");

        ReflectiveClassBuildItem reflectiveNativeClasses = ReflectiveClassBuildItem
                .builder(nativeInstantiated.toArray(new String[0]))
                .fields()
                .methods()
                .constructors()
                .build();
        reflectionProducer.produce(reflectiveNativeClasses);

        JniRuntimeAccessBuildItem jniAccess = new JniRuntimeAccessBuildItem(true, true, true,
                nativeInstantiated.toArray(new String[0]));
        jniProducer.produce(jniAccess);

        // TODO should only select the target architecture's libs
        NativeImageResourcePatternsBuildItem nativeLibraries = NativeImageResourcePatternsBuildItem.builder()
                .includeGlobs("ai/onnxruntime/native/**", "native/lib/**")
                .build();
        nativePatternProducer.produce(nativeLibraries);

        ReflectiveClassBuildItem sentenceDetectorFactory = ReflectiveClassBuildItem
                .builder("opennlp.tools.sentdetect.SentenceDetectorFactory")
                .build();
        reflectionProducer.produce(sentenceDetectorFactory);

        ReflectiveClassBuildItem onnxTensor = ReflectiveClassBuildItem
                .builder("ai.onnxruntime.OnnxTensor")
                .methods()
                .fields()
                .constructors()
                .build();
        reflectionProducer.produce(onnxTensor);
    }

    /**
     * Marks the ONNX runtime and tokenizer classes for runtime initialization, skipping any
     * class that {@code QuarkusClassLoader.isClassPresentAtRuntime} reports as absent.
     *
     * @param inProcessEmbeddingBuildItems the in-process embedding build items (not read by this step)
     * @param classProducer receives one runtime-initialized class item per present class
     * @param packageProducer producer for runtime-initialized packages (not used by this step)
     */
    @BuildStep
    @Consume(RequireOnnxRuntimeBuildItem.class)
    void onnxRuntimeClasses(
            List<InProcessEmbeddingBuildItem> inProcessEmbeddingBuildItems,
            BuildProducer<RuntimeInitializedClassBuildItem> classProducer,
            BuildProducer<RuntimeInitializedPackageBuildItem> packageProducer) {
        Stream<String> runtimeInitCandidates = Stream.of(
                "dev.langchain4j.model.embedding.OnnxBertBiEncoder",
                "dev.langchain4j.model.embedding.HuggingFaceTokenizer",
                "ai.djl.huggingface.tokenizers.HuggingFaceTokenizer",
                "ai.djl.huggingface.tokenizers.jni.TokenizersLibrary",
                "ai.djl.huggingface.tokenizers.jni.LibUtils",
                "ai.djl.util.Platform",
                "ai.onnxruntime.OrtEnvironment",
                "ai.onnxruntime.OnnxRuntime",
                "ai.onnxruntime.OnnxTensorLike",
                "ai.onnxruntime.OrtAllocator",
                "ai.onnxruntime.OrtSession$SessionOptions",
                "ai.onnxruntime.OrtSession");

        // Only classes actually available at runtime are registered.
        runtimeInitCandidates
                .filter(QuarkusClassLoader::isClassPresentAtRuntime)
                .forEach(presentClass -> classProducer.produce(new RuntimeInitializedClassBuildItem(presentClass)));
    }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package io.quarkiverse.langchain4j.deployment;

import io.quarkus.builder.item.SimpleBuildItem;

/**
 * A build item that is used to require the OnnxRuntime to be built into the native image.
 * <p>
 * This is a pure marker: it declares no fields or methods. It is produced when an
 * {@code InProcessEmbeddingBuildItem} reports {@code requireOnnxRuntime()}, and the build steps of
 * {@code OnnxRuntimeProcessor} declare it through {@code @Consume}.
 */
public final class RequireOnnxRuntimeBuildItem extends SimpleBuildItem {
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,22 @@ public final class InProcessEmbeddingBuildItem extends MultiBuildItem implements
private final String vocabularyPath;

private final String className;
private boolean requireOnnxRuntime;

/**
 * Creates a build item describing one in-process embedding model.
 * The new item is flagged as requiring the ONNX runtime.
 *
 * @param modelName the value stored as the model name
 * @param className the value stored as the class name
 * @param onnxModelPath the value stored as the ONNX model path
 * @param vocabularyPath the value stored as the vocabulary path
 */
public InProcessEmbeddingBuildItem(String modelName, String className, String onnxModelPath, String vocabularyPath) {
    // The ONNX runtime is required unless setRequireOnnxRuntime(false) is called afterwards.
    this.requireOnnxRuntime = true;

    this.vocabularyPath = vocabularyPath;
    this.onnxModelPath = onnxModelPath;
    this.className = className;
    this.modelName = modelName;
}

/**
 * Tells whether this embedding build item requires the ONNX runtime.
 *
 * @return the current value of the flag; the constructor initializes it to {@code true}
 */
public boolean requireOnnxRuntime() {
return requireOnnxRuntime;
}

/**
 * Overrides the flag returned by {@code requireOnnxRuntime()}.
 *
 * @param requireOnnxRuntime the new value of the flag
 */
public void setRequireOnnxRuntime(boolean requireOnnxRuntime) {
this.requireOnnxRuntime = requireOnnxRuntime;
}

public String modelName() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,14 @@

<parent>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-integration-tests-parent</artifactId>
<artifactId>quarkus-langchain4j-integration-tests-in-process-embedding-models</artifactId>
<version>999-SNAPSHOT</version>
</parent>

<artifactId>quarkus-langchain4j-integration-test-embed-all-minilm-l6-v2-q</artifactId>
<name>Quarkus LangChain4j - Integration Tests - embeddings-all-minilm-l6-v2-q</name>

<dependencies>
<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-parsers-base</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-core</artifactId>
Expand All @@ -27,12 +22,6 @@
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings-all-minilm-l6-v2-q</artifactId>
<version>${langchain4j-embeddings.version}</version>
<exclusions>
<exclusion>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
</exclusion>
</exclusions>
</dependency>

<dependency>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,24 @@

<parent>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-integration-tests-parent</artifactId>
<artifactId>quarkus-langchain4j-integration-tests-in-process-embedding-models</artifactId>
<version>999-SNAPSHOT</version>
</parent>

<artifactId>quarkus-langchain4j-integration-test-embed-all-minilm-l6-v2</artifactId>
<name>Quarkus LangChain4j - Integration Tests - embeddings-all-minilm-l6-v2</name>

<dependencies>
<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-parsers-base</artifactId>
<version>${project.version}</version>
</dependency>

<dependencies>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings-all-minilm-l6-v2</artifactId>
<version>${langchain4j-embeddings.version}</version>
<exclusions>
<exclusion>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
</exclusion>
</exclusions>
</dependency>

<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-core</artifactId>
<version>${project.version}</version>
</dependency>

<dependency>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ public class InProcessEmbeddingResource {

@POST
public String computeEmbedding(String sentence) {

var r1 = allMiniLmL6V2QuantizedEmbeddingModel.embed(sentence);
var r2 = embeddingModel.embed(sentence);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,24 @@

<parent>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-integration-tests-parent</artifactId>
<artifactId>quarkus-langchain4j-integration-tests-in-process-embedding-models</artifactId>
<version>999-SNAPSHOT</version>
</parent>

<artifactId>quarkus-langchain4j-integration-test-embed-bge-small-en-q</artifactId>
<name>Quarkus LangChain4j - Integration Tests - embeddings-bge-small-en-q</name>

<dependencies>
<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-parsers-base</artifactId>
<version>${project.version}</version>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings-bge-small-en-q</artifactId>
<version>${langchain4j-embeddings.version}</version>
<exclusions>
<exclusion>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
</exclusion>
</exclusions>
</dependency>

<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-core</artifactId>
<version>${project.version}</version>
</dependency>

<dependency>
Expand Down
Loading

0 comments on commit a4d7193

Please sign in to comment.