Commit 4308028: add basic ort example

krajewskiML committed Jul 28, 2023 (1 parent: 74bf977)

Showing 2 changed files with 57 additions and 81 deletions.
bert/Cargo.toml (1 addition, 1 deletion)

@@ -26,7 +26,7 @@ xayn-test-utils = { path = "../test-utils" }
 criterion = { workspace = true }
 csv = { workspace = true }
 indicatif = "0.17.5"
-onnxruntime = {version = ">=0.0.2", features = ["model-fetching"] }
+# onnxruntime = {version = ">=0.0.2", features = ["model-fetching"] }
 ort = {version = "1.15.2", features = ["download-binaries"] }
 rand = "0.8.5"
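The dependency change above moves the example from the onnxruntime bindings to the ort crate, which was already present; the old line is kept as a comment rather than dropped. With the download-binaries feature, ort fetches a prebuilt ONNX Runtime library at build time, so no system-wide install is needed. Below is a minimal sketch of the setup this dependency enables, assuming the ort 1.15 builder API that this commit's imports point to; the name and model path are placeholders.

use ort::{Environment, GraphOptimizationLevel, LoggingLevel, OrtResult, SessionBuilder};

fn main() -> OrtResult<()> {
    // One environment per process; `into_arc` wraps it in an Arc so that
    // several sessions (and threads) can share the same runtime.
    let environment = Environment::builder()
        .with_name("ort-smoke-test")
        .with_log_level(LoggingLevel::Warning)
        .build()?
        .into_arc();

    // A session wraps one loaded ONNX model; the path is a placeholder.
    let _session = SessionBuilder::new(&environment)?
        .with_optimization_level(GraphOptimizationLevel::Level1)?
        .with_model_from_file("path/to/model.onnx")?;

    Ok(())
}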
bert/examples/ort_check.rs (56 additions, 80 deletions)
@@ -2,17 +2,21 @@ use std::error;
 use std::io::{self, Write};
 
 use ndarray::{array, concatenate, s, Array1, Axis, CowArray};
-use ort::{download::language::machine_comprehension::GPT2, tensor::OrtOwnedTensor, Environment, ExecutionProvider, GraphOptimizationLevel, OrtResult, SessionBuilder, Value, LoggingLevel};
-use rand::Rng;
+use ndarray::{Array2, ArrayBase, Ix2, OwnedRepr};
+use ort::{
+    download::language::machine_comprehension::GPT2, tensor::OrtOwnedTensor, Environment,
+    ExecutionProvider, GraphOptimizationLevel, LoggingLevel, OrtResult, Session, SessionBuilder,
+    Value,
+};
+use rand::Rng;
 use serde::de::Error;
-use xayn_ai_bert::{tokenizer::Tokenizer, Config, FirstPooler};
-use xayn_test_utils::asset::{smbert, sts_data, zdf_data};
 use tokenizers::{
     tokenizer::Tokenizer as HfTokenizer, PaddingDirection, PaddingParams, PaddingStrategy,
     TruncationDirection, TruncationParams, TruncationStrategy,
 };
-
 use xayn_ai_bert::tokenizer::Encoding;
+use xayn_ai_bert::{tokenizer::Tokenizer, Config, FirstPooler};
+use xayn_test_utils::asset::{smbert, sts_data, zdf_data};
 
 fn create_tokenizer(token_size: usize) -> Result<HfTokenizer, Box<dyn std::error::Error>> {
     let config = Config::new(smbert()?)?.with_token_size(token_size)?;
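The tokenizers imports above feed create_tokenizer, whose body is collapsed in this view. The sketch below shows what a fixed-length tokenizer setup of this kind typically looks like with the tokenizers crate; fixed_size_tokenizer and the file path are hypothetical, the real function derives its settings from the smbert Config, and with_truncation returned &mut Self in the 2023-era API assumed here (it returns a Result in newer releases).

use tokenizers::{
    tokenizer::Tokenizer as HfTokenizer, PaddingDirection, PaddingParams, PaddingStrategy,
    TruncationDirection, TruncationParams, TruncationStrategy,
};

// Hypothetical stand-in for create_tokenizer: pad and truncate every
// encoding to exactly `token_size` tokens so the model input shape is fixed.
fn fixed_size_tokenizer(
    path: &str,
    token_size: usize,
) -> Result<HfTokenizer, Box<dyn std::error::Error>> {
    let mut tokenizer = HfTokenizer::from_file(path)?;
    tokenizer.with_padding(Some(PaddingParams {
        strategy: PaddingStrategy::Fixed(token_size),
        direction: PaddingDirection::Right,
        ..Default::default()
    }));
    tokenizer.with_truncation(Some(TruncationParams {
        max_length: token_size,
        strategy: TruncationStrategy::LongestFirst,
        direction: TruncationDirection::Right,
        ..Default::default()
    }));
    Ok(tokenizer)
}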
@@ -52,87 +52,59 @@ fn main() -> Result<(), Box<dyn error::Error>> {
         .build()?
         .into_arc();
 
-    let mut session = SessionBuilder::new(&environment)?
-        .with_model_from_file("/Users/maciejkrajewski/CLionProjects/xayn_discovery_engine/assets/smbert_v0004/model.onnx")?;
+    let mut session = SessionBuilder::new(&environment)?.with_model_from_file(
+        "/Users/maciejkrajewski/CLionProjects/xayn_discovery_engine/assets/smbert_v0004/model.onnx",
+    )?;
 
     let text = "This is a test sentence";
     let tokenizer_output = tokenizer.encode(text, false).unwrap();
 
-    let tokens = tokenizer_output.get_ids().iter().map(|i| *i as i64).collect::<Vec<_>>();
-    let mut tokens = CowArray::from(Array1::from_iter(tokens.iter().cloned()));
-    let array_tokens = tokens.clone().insert_axis(Axis(0)).into_shape((1, tokens.shape()[0])).unwrap().into_dyn();
-    let token_ids = Value::from_array(session.allocator(), &array_tokens)?;
+    let tokens = tokenizer_output
+        .get_ids()
+        .iter()
+        .map(|i| *i as i64)
+        .collect::<Vec<_>>();
+    let tokens = CowArray::from(Array1::from_iter(tokens.iter().cloned()));
+    let array_tokens = tokens
+        .clone()
+        .insert_axis(Axis(0))
+        .into_shape((1, tokens.shape()[0]))
+        .unwrap()
+        .into_dyn();
 
-    let attention_mask = tokenizer_output.get_attention_mask().iter().map(|i| *i as i64).collect::<Vec<_>>();
-    let mut attention_mask = CowArray::from(Array1::from_iter(attention_mask.iter().cloned()));
-    let array_attention = attention_mask.clone().insert_axis(Axis(0)).into_shape((1, attention_mask.shape()[0])).unwrap().into_dyn();
-    let attention_mask = Value::from_array(session.allocator(), &array_attention)?;
+    let attention_mask = tokenizer_output
+        .get_attention_mask()
+        .iter()
+        .map(|i| *i as i64)
+        .collect::<Vec<_>>();
+    let attention_mask = CowArray::from(Array1::from_iter(attention_mask.iter().cloned()));
+    let array_attention = attention_mask
+        .clone()
+        .insert_axis(Axis(0))
+        .into_shape((1, attention_mask.shape()[0]))
+        .unwrap()
+        .into_dyn();
 
-    let token_type_ids = tokenizer_output.get_type_ids().iter().map(|i| *i as i64).collect::<Vec<_>>();
-    let mut token_type_ids = CowArray::from(Array1::from_iter(token_type_ids.iter().cloned()));
-    let array_types = token_type_ids.clone().insert_axis(Axis(0)).into_shape((1, token_type_ids.shape()[0])).unwrap().into_dyn();
-    let token_type_ids = Value::from_array(session.allocator(), &array_types)?;
+    let token_type_ids = tokenizer_output
+        .get_type_ids()
+        .iter()
+        .map(|i| *i as i64)
+        .collect::<Vec<_>>();
+    let token_type_ids = CowArray::from(Array1::from_iter(token_type_ids.iter().cloned()));
+    let array_types = token_type_ids
+        .clone()
+        .insert_axis(Axis(0))
+        .into_shape((1, token_type_ids.shape()[0]))
+        .unwrap()
+        .into_dyn();
 
-    let in_tensor = vec![token_ids, attention_mask, token_type_ids];
+    let in_tensor = vec![
+        Value::from_array(session.allocator(), &array_tokens)?,
+        Value::from_array(session.allocator(), &array_attention)?,
+        Value::from_array(session.allocator(), &array_types)?,
+    ];
+    println!("before {:?}", in_tensor);
     let outputs = session.run(in_tensor)?;
+    println!("after {:?}", outputs);
     Ok(())
+}
-
-// fn main2() -> OrtResult<()> {
-// const PROMPT: &str = "The corsac fox (Vulpes corsac), also known simply as a corsac, is a medium-sized fox found in";
-// const GEN_TOKENS: i32 = 90;
-// const TOP_K: usize = 5;
-//
-// let mut stdout = io::stdout();
-// let mut rng = rand::thread_rng();
-//
-// let environment = Environment::builder()
-// .with_name("GPT-2")
-// .with_execution_providers([ExecutionProvider::CUDA(Default::default())])
-// .build()?
-// .into_arc();
-//
-// // let session = SessionBuilder::new(&environment)?
-// // .with_optimization_level(GraphOptimizationLevel::Level1)?
-// // .with_intra_threads(1)?
-// // .with_model_downloaded(GPT2::GPT2LmHead)?;
-//
-// let tokenizer = tokenizers::Tokenizer::from_file("tests/data/gpt2-tokenizer.json").unwrap();
-// let tokens = tokenizer.encode(PROMPT, false).unwrap();
-// let tokens = tokens.get_ids().iter().map(|i| *i as i64).collect::<Vec<_>>();
-//
-// let mut tokens = CowArray::from(Array1::from_iter(tokens.iter().cloned()));
-//
-// print!("{PROMPT}");
-// stdout.flush().unwrap();
-//
-// for _ in 0..GEN_TOKENS {
-// let n_tokens = tokens.shape()[0];
-// let array = tokens.clone().insert_axis(Axis(0)).into_shape((1, 1, n_tokens)).unwrap().into_dyn();
-// let inputs = vec![Value::from_array(session.allocator(), &array)?];
-// let outputs: Vec<Value> = session.run(inputs)?;
-// let generated_tokens: OrtOwnedTensor<f32, _> = outputs[0].try_extract()?;
-// let generated_tokens = generated_tokens.view();
-//
-// let probabilities = &mut generated_tokens
-// .slice(s![0, 0, -1, ..])
-// .insert_axis(Axis(0))
-// .to_owned()
-// .iter()
-// .cloned()
-// .enumerate()
-// .collect::<Vec<_>>();
-// probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Less));
-//
-// let token = probabilities[rng.gen_range(0..=TOP_K)].0;
-// tokens = CowArray::from(concatenate![Axis(0), tokens, array![token.try_into().unwrap()]]);
-//
-// let token_str = tokenizer.decode(vec![token as _], true).unwrap();
-// print!("{}", token_str);
-// stdout.flush().unwrap();
-// }
-//
-// println!();
-//
-// Ok(())
-// }
-}
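The three input blocks in the new main are identical except for the accessor they read: get_ids, get_attention_mask, get_type_ids. A possible refactor, not part of this commit, could factor the i64 cast and the batch reshape into one helper; to_model_input is a hypothetical name, and insert_axis alone already produces the (1, seq_len) shape, which makes the extra into_shape call in the committed code redundant.

use ndarray::{Array1, Axis, CowArray, IxDyn};

// Hypothetical helper (not in this commit): cast tokenizer ids to i64 and
// prepend the batch axis, yielding the dynamic-rank array that
// Value::from_array takes by reference.
fn to_model_input(ids: &[u32]) -> CowArray<'static, i64, IxDyn> {
    CowArray::from(
        Array1::from_iter(ids.iter().map(|&i| i as i64))
            .insert_axis(Axis(0)) // (seq_len,) -> (1, seq_len)
            .into_dyn(),
    )
}

With that, each input in main reduces to one line per tensor, e.g. let array_tokens = to_model_input(tokenizer_output.get_ids());, and the arrays still have to outlive the Value::from_array(session.allocator(), &array_tokens)? calls that borrow them.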

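For now main only Debug-prints the outputs, but the deleted GPT-2 scratch above already demonstrates the ort 1.x extraction pattern: try_extract into an OrtOwnedTensor. Below is a sketch of applying that here, assuming the model's first output is an f32 tensor; print_first_output is a hypothetical helper.

use ort::{tensor::OrtOwnedTensor, OrtResult, Value};

// Hypothetical helper (not in this commit): view the first session output
// as an f32 tensor and report its shape, without copying the data.
fn print_first_output(outputs: &[Value]) -> OrtResult<()> {
    let tensor: OrtOwnedTensor<f32, _> = outputs[0].try_extract()?;
    println!("first output shape: {:?}", tensor.view().shape());
    Ok(())
}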