Adding various execution providers
deven96 committed Dec 16, 2024
1 parent 238c3d2 commit f24e419
Showing 4 changed files with 30 additions and 1 deletion.
5 changes: 5 additions & 0 deletions ahnlich/.cargo/config.toml
@@ -0,0 +1,5 @@
[target.aarch64-apple-darwin]
rustflags = ["-Clink-arg=-fapple-link-rtlib"]

[target.x86_64-apple-darwin]
rustflags = ["-Clink-arg=-fapple-link-rtlib"]
1 change: 1 addition & 0 deletions ahnlich/Cargo.lock

Some generated files are not rendered by default.

9 changes: 8 additions & 1 deletion ahnlich/ai/Cargo.toml
@@ -40,7 +40,14 @@ fallible_collections.workspace = true
rayon.workspace = true
hf-hub = { version = "0.3", default-features = false }
dirs = "5.0.1"
ort = { version = "=2.0.0-rc.5", features = ["ndarray"] }
ort = { version = "=2.0.0-rc.5", features = [
"ndarray",
"directml",
"tensorrt",
"cuda",
"coreml",
] }
ort-sys = "=2.0.0-rc.8"
moka = { version = "0.12.8", features = ["future"] }
tracing-opentelemetry.workspace = true
futures.workspace = true
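Each of these Cargo features only compiles in the binding for the corresponding ONNX Runtime execution provider; whether a provider can actually be used is still decided at runtime on the host machine. Below is a minimal sketch of probing availability up front, assuming ort 2.0.0-rc.5 exposes the ExecutionProvider trait and its is_available helper at the crate root:

use ort::{CUDAExecutionProvider, CoreMLExecutionProvider, ExecutionProvider};

fn report_available_providers() -> ort::Result<()> {
    // Each probe returns Ok(true) only when the provider's native libraries
    // can be loaded on this machine; otherwise ort falls back to the CPU provider.
    println!("CUDA: {}", CUDAExecutionProvider::default().is_available()?);
    println!("CoreML: {}", CoreMLExecutionProvider::default().is_available()?);
    Ok(())
}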
16 changes: 16 additions & 0 deletions ahnlich/ai/src/engine/ai/providers/ort.rs
@@ -5,6 +5,10 @@ use crate::error::AIProxyError;
use fallible_collections::FallibleVec;
use hf_hub::{api::sync::ApiBuilder, Cache};
use itertools::Itertools;
use ort::{
CUDAExecutionProvider, CoreMLExecutionProvider, DirectMLExecutionProvider,
TensorRTExecutionProvider,
};
use ort::{Session, SessionOutputs, Value};
use rayon::prelude::*;

@@ -350,6 +354,18 @@ impl ProviderTrait for ORTProvider {
}

fn load_model(&mut self) -> Result<(), AIProxyError> {
ort::init()
.with_execution_providers([
// Prefer TensorRT over CUDA.
TensorRTExecutionProvider::default().build(),
CUDAExecutionProvider::default().build(),
// Use DirectML on Windows if NVIDIA EPs are not available
DirectMLExecutionProvider::default().build(),
// Or use ANE on Apple platforms
CoreMLExecutionProvider::default().build(),
])
.commit()?;

let Some(cache_location) = self.cache_location.clone() else {
return Err(AIProxyError::CacheLocationNotInitiailized);
};
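Because the providers are registered globally through ort::init(...).commit(), any Session built afterwards tries them in registration order (TensorRT, then CUDA, then DirectML, then CoreML) and falls back to the default CPU provider when none can be loaded. A minimal sketch of a session picking this up, assuming ort 2.0.0-rc.5's SessionBuilder API and a hypothetical model.onnx path used purely for illustration:

use ort::{GraphOptimizationLevel, Session};

fn build_session() -> ort::Result<Session> {
    // No per-session provider configuration is needed here: the session
    // inherits the execution providers registered by ort::init() above.
    Session::builder()?
        .with_optimization_level(GraphOptimizationLevel::Level3)?
        .commit_from_file("model.onnx") // hypothetical path, for illustration only
}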
