Setting gpu execution providers as features
deven96 committed Dec 16, 2024
1 parent f363f47 commit 8fe42b2
Showing 3 changed files with 19 additions and 6 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/release.yml
@@ -68,7 +68,7 @@ jobs:
- name: Build Aarch64 Darwin Release for ${{ needs.prebuild_preparation.outputs.bin_name }}
working-directory: ./ahnlich
run: |
-cargo build --release --target aarch64-apple-darwin --bin ${{ needs.prebuild_preparation.outputs.bin_name }}
+cargo build --features coreml --release --target aarch64-apple-darwin --bin ${{ needs.prebuild_preparation.outputs.bin_name }}
tar -cvzf aarch64-darwin-${{ needs.prebuild_preparation.outputs.bin_name }}.tar.gz -C target/aarch64-apple-darwin/release ${{ needs.prebuild_preparation.outputs.bin_name }}
gh release upload ${{github.event.release.tag_name}} aarch64-darwin-${{ needs.prebuild_preparation.outputs.bin_name }}.tar.gz
@@ -93,7 +93,7 @@ jobs:
- name: Build x86_64 Apple Darwin Release for ${{ needs.prebuild_preparation.outputs.bin_name }}
working-directory: ./ahnlich
run: |
-cargo build --release --target x86_64-apple-darwin --bin ${{ needs.prebuild_preparation.outputs.bin_name }}
+cargo build --features coreml --release --target x86_64-apple-darwin --bin ${{ needs.prebuild_preparation.outputs.bin_name }}
tar -cvzf x86_64-apple-darwin-${{ needs.prebuild_preparation.outputs.bin_name }}.tar.gz -C target/x86_64-apple-darwin/release ${{ needs.prebuild_preparation.outputs.bin_name }}
gh release upload ${{github.event.release.tag_name}} x86_64-apple-darwin-${{ needs.prebuild_preparation.outputs.bin_name }}.tar.gz
14 changes: 11 additions & 3 deletions ahnlich/ai/Cargo.toml
@@ -42,9 +42,6 @@ hf-hub = { version = "0.3", default-features = false }
dirs = "5.0.1"
ort = { version = "=2.0.0-rc.5", features = [
"ndarray",
"tensorrt",
"cuda",
"coreml",
] }
ort-sys = "=2.0.0-rc.8"
moka = { version = "0.12.8", features = ["future"] }
@@ -53,6 +50,17 @@ futures.workspace = true
tiktoken-rs = "0.5.9"
itertools.workspace = true
tokenizers = { version = "0.20.1", features = ["hf-hub"] }

+[features]
+# ORT Execution providers
+default = ["tensorrt", "cuda"]
+tensorrt = ["ort/tensorrt"]
+cuda = ["ort/cuda"]
+# activate only on apple devices
+coreml = ["ort/coreml"]
+# activate only on windows devices
+directml = ["ort/directml"]

[dev-dependencies]
db = { path = "../db", version = "*" }
pretty_assertions.workspace = true
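
Note on the feature wiring: each new feature simply forwards to the matching ort feature, and since default = ["tensorrt", "cuda"], passing --features coreml (as the release workflow above does) adds CoreML on top of the default providers rather than replacing them, unless --no-default-features is also given. As a minimal illustrative sketch (not part of this commit; the function name is hypothetical), code can check at compile time which providers a binary was built with:

fn compiled_execution_providers() -> Vec<&'static str> {
    // cfg!(feature = "...") resolves at compile time from the [features]
    // table declared above, so this mirrors whatever `cargo build --features ...` enabled.
    let mut eps = Vec::new();
    if cfg!(feature = "tensorrt") { eps.push("tensorrt"); }
    if cfg!(feature = "cuda") { eps.push("cuda"); }
    if cfg!(feature = "coreml") { eps.push("coreml"); }
    if cfg!(feature = "directml") { eps.push("directml"); }
    eps
}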
7 changes: 6 additions & 1 deletion ahnlich/ai/src/engine/ai/providers/ort.rs
@@ -5,7 +5,10 @@ use crate::error::AIProxyError;
use fallible_collections::FallibleVec;
use hf_hub::{api::sync::ApiBuilder, Cache};
use itertools::Itertools;
-use ort::{CUDAExecutionProvider, CoreMLExecutionProvider, TensorRTExecutionProvider};
+use ort::{
+    CUDAExecutionProvider, CoreMLExecutionProvider, DirectMLExecutionProvider,
+    TensorRTExecutionProvider,
+};
use ort::{Session, SessionOutputs, Value};
use rayon::prelude::*;

@@ -356,6 +359,8 @@ impl ProviderTrait for ORTProvider {
// Prefer TensorRT over CUDA.
TensorRTExecutionProvider::default().build(),
CUDAExecutionProvider::default().build(),
+// Use DirectML on Windows if NVIDIA EPs are not available
+DirectMLExecutionProvider::default().build(),
// Or use ANE on Apple platforms
CoreMLExecutionProvider::default().build(),
])
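
For context on how a provider list like this is consumed, below is a minimal standalone sketch (not the crate's actual code; the model path is a placeholder) of registering the same providers on an ort session builder. ort tries the providers in the order given and falls back to the next one, and finally to the CPU provider, when a backend is not compiled in or not usable at runtime:

use ort::{
    CUDAExecutionProvider, CoreMLExecutionProvider, DirectMLExecutionProvider, Session,
    TensorRTExecutionProvider,
};

fn build_session() -> ort::Result<Session> {
    Session::builder()?
        .with_execution_providers([
            // Same priority order as the diff above: TensorRT, CUDA, DirectML, CoreML.
            TensorRTExecutionProvider::default().build(),
            CUDAExecutionProvider::default().build(),
            DirectMLExecutionProvider::default().build(),
            CoreMLExecutionProvider::default().build(),
        ])?
        .commit_from_file("model.onnx")
}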
