Skip to content

Commit

Permalink
upgrade ort to v2.0.0-rc.9 (#52)
Browse files Browse the repository at this point in the history
  • Loading branch information
TD-Sky authored Dec 3, 2024
1 parent 57db14c commit 2785b09
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 58 deletions.
5 changes: 2 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ exclude = ["assets/*", "examples/*", "scripts/*", "runs/*"]
[dependencies]
clap = { version = "4.2.4", features = ["derive"] }
ndarray = { version = "0.16.1", features = ["rayon"] }
ort = { version = "2.0.0-rc.5", default-features = false}
ort = { version = "2.0.0-rc.9", default-features = false }
anyhow = { version = "1.0.75" }
regex = { version = "1.5.4" }
rand = { version = "0.8.5" }
Expand All @@ -30,7 +30,7 @@ imageproc = { version = "0.24" }
ab_glyph = "0.2.23"
geo = "0.28.0"
prost = "0.12.4"
fast_image_resize = { version = "4.2.1", features = ["image"]}
fast_image_resize = { version = "4.2.1", features = ["image"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
tempfile = "3.12.0"
Expand All @@ -50,7 +50,6 @@ default = [
"ort/cuda",
"ort/tensorrt",
"ort/coreml",
"ort/operator-libraries"
]
auto = ["ort/download-binaries"]

Expand Down
113 changes: 58 additions & 55 deletions src/core/ort_engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use anyhow::Result;
use half::f16;
use ndarray::{Array, IxDyn};
use ort::{
ExecutionProvider, Session, SessionBuilder, TensorElementType, TensorRTExecutionProvider,
execution_providers::{ExecutionProvider, TensorRTExecutionProvider},
session::{builder::SessionBuilder, Session},
tensor::TensorElementType,
};
use prost::Message;
use std::collections::HashSet;
Expand Down Expand Up @@ -88,38 +90,38 @@ impl OrtEngine {

// build
ort::init().commit()?;
let builder = Session::builder()?;
let mut builder = Session::builder()?;
let mut device = config.device.to_owned();
match device {
Device::Trt(device_id) => {
Self::build_trt(
&inputs_attrs.names,
&inputs_minoptmax,
&builder,
&mut builder,
device_id,
config.trt_int8_enable,
config.trt_fp16_enable,
config.trt_engine_cache_enable,
)?;
}
Device::Cuda(device_id) => {
Self::build_cuda(&builder, device_id).unwrap_or_else(|err| {
Self::build_cuda(&mut builder, device_id).unwrap_or_else(|err| {
tracing::warn!("{err}, Using cpu");
device = Device::Cpu(0);
})
}
Device::CoreML(_) => Self::build_coreml(&builder).unwrap_or_else(|err| {
Device::CoreML(_) => Self::build_coreml(&mut builder).unwrap_or_else(|err| {
tracing::warn!("{err}, Using cpu");
device = Device::Cpu(0);
}),
Device::Cpu(_) => {
Self::build_cpu(&builder)?;
Self::build_cpu(&mut builder)?;
}
_ => todo!(),
}

let session = builder
.with_optimization_level(ort::GraphOptimizationLevel::Level3)?
.with_optimization_level(ort::session::builder::GraphOptimizationLevel::Level3)?
.commit_from_file(&config.onnx_path)?;

// summary
Expand Down Expand Up @@ -149,7 +151,7 @@ impl OrtEngine {
fn build_trt(
names: &[String],
inputs_minoptmax: &[Vec<MinOptMax>],
builder: &SessionBuilder,
builder: &mut SessionBuilder,
device_id: usize,
int8_enable: bool,
fp16_enable: bool,
Expand Down Expand Up @@ -205,26 +207,27 @@ impl OrtEngine {
}
}

fn build_cuda(builder: &SessionBuilder, device_id: usize) -> Result<()> {
let ep = ort::CUDAExecutionProvider::default().with_device_id(device_id as i32);
fn build_cuda(builder: &mut SessionBuilder, device_id: usize) -> Result<()> {
let ep = ort::execution_providers::CUDAExecutionProvider::default()
.with_device_id(device_id as i32);
if ep.is_available()? && ep.register(builder).is_ok() {
Ok(())
} else {
anyhow::bail!("{CROSS_MARK} CUDA initialization failed")
}
}

fn build_coreml(builder: &SessionBuilder) -> Result<()> {
let ep = ort::CoreMLExecutionProvider::default().with_subgraphs(); //.with_ane_only();
fn build_coreml(builder: &mut SessionBuilder) -> Result<()> {
let ep = ort::execution_providers::CoreMLExecutionProvider::default().with_subgraphs(); //.with_ane_only();
if ep.is_available()? && ep.register(builder).is_ok() {
Ok(())
} else {
anyhow::bail!("{CROSS_MARK} CoreML initialization failed")
}
}

fn build_cpu(builder: &SessionBuilder) -> Result<()> {
let ep = ort::CPUExecutionProvider::default();
fn build_cpu(builder: &mut SessionBuilder) -> Result<()> {
let ep = ort::execution_providers::CPUExecutionProvider::default();
if ep.is_available()? && ep.register(builder).is_ok() {
Ok(())
} else {
Expand Down Expand Up @@ -292,28 +295,28 @@ impl OrtEngine {
let t_pre = std::time::Instant::now();
for (idtype, x) in self.inputs_attrs.dtypes.iter().zip(xs.into_iter()) {
let x_ = match &idtype {
TensorElementType::Float32 => ort::Value::from_array(x.view())?.into_dyn(),
TensorElementType::Float32 => ort::value::Value::from_array(x.view())?.into_dyn(),
TensorElementType::Float16 => {
ort::Value::from_array(x.mapv(f16::from_f32).view())?.into_dyn()
ort::value::Value::from_array(x.mapv(f16::from_f32).view())?.into_dyn()
}
TensorElementType::Int32 => {
ort::Value::from_array(x.mapv(|x_| x_ as i32).view())?.into_dyn()
ort::value::Value::from_array(x.mapv(|x_| x_ as i32).view())?.into_dyn()
}
TensorElementType::Int64 => {
ort::Value::from_array(x.mapv(|x_| x_ as i64).view())?.into_dyn()
ort::value::Value::from_array(x.mapv(|x_| x_ as i64).view())?.into_dyn()
}
TensorElementType::Uint8 => {
ort::Value::from_array(x.mapv(|x_| x_ as u8).view())?.into_dyn()
ort::value::Value::from_array(x.mapv(|x_| x_ as u8).view())?.into_dyn()
}
TensorElementType::Int8 => {
ort::Value::from_array(x.mapv(|x_| x_ as i8).view())?.into_dyn()
ort::value::Value::from_array(x.mapv(|x_| x_ as i8).view())?.into_dyn()
}
TensorElementType::Bool => {
ort::Value::from_array(x.mapv(|x_| x_ != 0.).view())?.into_dyn()
ort::value::Value::from_array(x.mapv(|x_| x_ != 0.).view())?.into_dyn()
}
_ => todo!(),
};
xs_.push(Into::<ort::SessionInputValue<'_>>::into(x_));
xs_.push(Into::<ort::session::SessionInputValue<'_>>::into(x_));
}
let t_pre = t_pre.elapsed();
self.ts.add_or_push(0, t_pre);
Expand Down Expand Up @@ -451,45 +454,45 @@ impl OrtEngine {
}

#[allow(dead_code)]
fn nbytes_from_onnx_dtype(x: &ort::TensorElementType) -> usize {
fn nbytes_from_onnx_dtype(x: &ort::tensor::TensorElementType) -> usize {
match x {
ort::TensorElementType::Float64
| ort::TensorElementType::Uint64
| ort::TensorElementType::Int64 => 8, // i64, f64, u64
ort::TensorElementType::Float32
| ort::TensorElementType::Uint32
| ort::TensorElementType::Int32
| ort::TensorElementType::String => 4, // f32, i32, u32, string(1~4)
ort::TensorElementType::Float16
| ort::TensorElementType::Bfloat16
| ort::TensorElementType::Int16
| ort::TensorElementType::Uint16 => 2, // f16, bf16, i16, u16
ort::TensorElementType::Uint8
| ort::TensorElementType::Int8
| ort::TensorElementType::Bool => 1, // u8, i8, bool
ort::tensor::TensorElementType::Float64
| ort::tensor::TensorElementType::Uint64
| ort::tensor::TensorElementType::Int64 => 8, // i64, f64, u64
ort::tensor::TensorElementType::Float32
| ort::tensor::TensorElementType::Uint32
| ort::tensor::TensorElementType::Int32
| ort::tensor::TensorElementType::String => 4, // f32, i32, u32, string(1~4)
ort::tensor::TensorElementType::Float16
| ort::tensor::TensorElementType::Bfloat16
| ort::tensor::TensorElementType::Int16
| ort::tensor::TensorElementType::Uint16 => 2, // f16, bf16, i16, u16
ort::tensor::TensorElementType::Uint8
| ort::tensor::TensorElementType::Int8
| ort::tensor::TensorElementType::Bool => 1, // u8, i8, bool
}
}

#[allow(dead_code)]
fn ort_dtype_from_onnx_dtype_id(value: i32) -> Option<ort::TensorElementType> {
fn ort_dtype_from_onnx_dtype_id(value: i32) -> Option<ort::tensor::TensorElementType> {
match value {
0 => None,
1 => Some(ort::TensorElementType::Float32),
2 => Some(ort::TensorElementType::Uint8),
3 => Some(ort::TensorElementType::Int8),
4 => Some(ort::TensorElementType::Uint16),
5 => Some(ort::TensorElementType::Int16),
6 => Some(ort::TensorElementType::Int32),
7 => Some(ort::TensorElementType::Int64),
8 => Some(ort::TensorElementType::String),
9 => Some(ort::TensorElementType::Bool),
10 => Some(ort::TensorElementType::Float16),
11 => Some(ort::TensorElementType::Float64),
12 => Some(ort::TensorElementType::Uint32),
13 => Some(ort::TensorElementType::Uint64),
1 => Some(ort::tensor::TensorElementType::Float32),
2 => Some(ort::tensor::TensorElementType::Uint8),
3 => Some(ort::tensor::TensorElementType::Int8),
4 => Some(ort::tensor::TensorElementType::Uint16),
5 => Some(ort::tensor::TensorElementType::Int16),
6 => Some(ort::tensor::TensorElementType::Int32),
7 => Some(ort::tensor::TensorElementType::Int64),
8 => Some(ort::tensor::TensorElementType::String),
9 => Some(ort::tensor::TensorElementType::Bool),
10 => Some(ort::tensor::TensorElementType::Float16),
11 => Some(ort::tensor::TensorElementType::Float64),
12 => Some(ort::tensor::TensorElementType::Uint32),
13 => Some(ort::tensor::TensorElementType::Uint64),
14 => None, // COMPLEX64
15 => None, // COMPLEX128
16 => Some(ort::TensorElementType::Bfloat16),
16 => Some(ort::tensor::TensorElementType::Bfloat16),
_ => None,
}
}
Expand All @@ -499,7 +502,7 @@ impl OrtEngine {
value_info: &[onnx::ValueInfoProto],
) -> Result<OrtTensorAttr> {
let mut dimss: Vec<Vec<usize>> = Vec::new();
let mut dtypes: Vec<ort::TensorElementType> = Vec::new();
let mut dtypes: Vec<ort::tensor::TensorElementType> = Vec::new();
let mut names: Vec<String> = Vec::new();
for v in value_info.iter() {
if initializer_names.contains(v.name.as_str()) {
Expand Down Expand Up @@ -569,7 +572,7 @@ impl OrtEngine {
&self.outputs_attrs.names
}

pub fn odtypes(&self) -> &Vec<ort::TensorElementType> {
pub fn odtypes(&self) -> &Vec<ort::tensor::TensorElementType> {
&self.outputs_attrs.dtypes
}

Expand All @@ -585,7 +588,7 @@ impl OrtEngine {
&self.inputs_attrs.names
}

pub fn idtypes(&self) -> &Vec<ort::TensorElementType> {
pub fn idtypes(&self) -> &Vec<ort::tensor::TensorElementType> {
&self.inputs_attrs.dtypes
}

Expand Down

0 comments on commit 2785b09

Please sign in to comment.