-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #130 from edgenai/feat/issue120
Feat/issue120: backend-agnostic endpoints
- Loading branch information
Showing
15 changed files
with
1,074 additions
and
349 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
[package] | ||
name = "edgen_rt_chat_faker" | ||
version = "0.1.0" | ||
edition = "2021" | ||
|
||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html | ||
|
||
[dependencies] | ||
blake3 = { workspace = true } | ||
dashmap = { workspace = true } | ||
derive_more = { workspace = true } | ||
edgen_core = { path = "../edgen_core" } | ||
futures = { workspace = true } | ||
thiserror = { workspace = true } | ||
tokio = { workspace = true, features = ["sync", "rt", "fs"] } | ||
tracing = { workspace = true } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
/* Copyright 2023- The Binedge, Lda team. All rights reserved. | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
//! A fake model RT for chat completions that answers with predefined strings | ||
|
||
use std::path::Path; | ||
use std::sync::Arc; | ||
|
||
use dashmap::DashMap; | ||
use futures::Stream; | ||
use tracing::info; | ||
|
||
use edgen_core::llm::{CompletionArgs, LLMEndpoint, LLMEndpointError}; | ||
use edgen_core::BoxedFuture; | ||
|
||
pub const CAPITAL: &str = "The capital of Canada is Ottawa."; | ||
pub const CAPITAL_OF_PORTUGAL: &str = "The capital of Portugal is Lisbon."; | ||
pub const DEFAULT_ANSWER: &str = "The answer is 42."; | ||
pub const LONG_ANSWER: &str = "Call me Ishmael. Some years ago—never mind how long precisely—having little or no money in my purse, and nothing particular to interest me on shore, I thought I would sail about a little and see the watery part of the world. It is a way I have of driving off the spleen and regulating circulation. Whenever I find myself growing grim about the mouth; whenever it is a damp, drizzly November in my soul; whenever I find myself involuntarily pausing before coffin warehouses, and bringing up the rear of every funeral I meet; and especially whenever my hypos get such an upper hand of me, that it requires a strong moral principle to prevent me from deliberately stepping into the street, and methodically knocking people’s hats off—then, I account it high time to get to sea as soon as I can. There is nothing surprising in this. If they but knew it, almost all men in their degree, some time or other, cherish very nearly the same feelings towards the ocean with me."; | ||
|
||
struct ChatFakerModel {} | ||
|
||
impl ChatFakerModel { | ||
async fn new(_path: impl AsRef<Path>) -> Self { | ||
Self {} | ||
} | ||
|
||
async fn chat_completions(&self, args: CompletionArgs) -> Result<String, LLMEndpointError> { | ||
info!("faking chat completions"); | ||
Ok(completions_for(&args.prompt)) | ||
} | ||
|
||
async fn stream_chat_completions( | ||
&self, | ||
args: CompletionArgs, | ||
) -> Result<Box<dyn Stream<Item = String> + Unpin + Send>, LLMEndpointError> { | ||
info!("faking stream chat completions"); | ||
let msg = completions_for(&args.prompt); | ||
let toks = streamify(&msg); | ||
Ok(Box::new(futures::stream::iter(toks.into_iter()))) | ||
} | ||
|
||
//TODO: implement | ||
async fn embeddings(&self, _inputs: &Vec<String>) -> Result<Vec<Vec<f32>>, LLMEndpointError> { | ||
info!("faking emeddings"); | ||
Ok(vec![]) | ||
} | ||
} | ||
|
||
fn completions_for(prompt: &str) -> String { | ||
let prompt = prompt.to_lowercase(); | ||
if prompt.contains("capital") { | ||
if prompt.contains("portugal") { | ||
return CAPITAL_OF_PORTUGAL.to_string(); | ||
} else { | ||
return CAPITAL.to_string(); | ||
} | ||
} else if prompt.contains("long") { | ||
return LONG_ANSWER.to_string(); | ||
} else { | ||
return DEFAULT_ANSWER.to_string(); | ||
} | ||
} | ||
|
||
fn streamify(msg: &str) -> Vec<String> { | ||
msg.split_whitespace().map(|s| s.to_string()).collect() | ||
} | ||
|
||
/// Faking a large language model endpoint, implementing [`LLMEndpoint`]. | ||
pub struct ChatFakerEndpoint { | ||
/// A map of the models currently loaded into memory, with their path as the key. | ||
models: Arc<DashMap<String, ChatFakerModel>>, | ||
} | ||
|
||
impl ChatFakerEndpoint { | ||
// This is not strictly needed because we have no Unloading models. | ||
// Anyway, it looks more like a real model. | ||
async fn get( | ||
&self, | ||
model_path: impl AsRef<Path>, | ||
) -> dashmap::mapref::one::Ref<String, ChatFakerModel> { | ||
let key = model_path.as_ref().to_string_lossy().to_string(); | ||
|
||
if !self.models.contains_key(&key) { | ||
let model = ChatFakerModel::new(model_path).await; | ||
self.models.insert(key.clone(), model); | ||
} | ||
|
||
// PANIC SAFETY: Just inserted the element if it isn't already inside the map, so must be present in the map | ||
self.models.get(&key).unwrap() | ||
} | ||
|
||
/// Helper `async` function that returns the full chat completions for the specified model and | ||
/// [`CompletionArgs`]. | ||
async fn async_chat_completions( | ||
&self, | ||
model_path: impl AsRef<Path>, | ||
args: CompletionArgs, | ||
) -> Result<String, LLMEndpointError> { | ||
let model = self.get(model_path).await; | ||
model.chat_completions(args).await | ||
} | ||
|
||
/// Helper `async` function that returns the chat completions stream for the specified model and | ||
/// [`CompletionArgs`]. | ||
async fn async_stream_chat_completions( | ||
&self, | ||
model_path: impl AsRef<Path>, | ||
args: CompletionArgs, | ||
) -> Result<Box<dyn Stream<Item = String> + Unpin + Send>, LLMEndpointError> { | ||
let model = self.get(model_path).await; | ||
model.stream_chat_completions(args).await | ||
} | ||
|
||
async fn async_embeddings( | ||
&self, | ||
model_path: impl AsRef<Path>, | ||
inputs: Vec<String>, | ||
) -> Result<Vec<Vec<f32>>, LLMEndpointError> { | ||
let model = self.get(model_path).await; | ||
model.embeddings(&inputs).await | ||
} | ||
} | ||
|
||
impl LLMEndpoint for ChatFakerEndpoint { | ||
fn chat_completions<'a>( | ||
&'a self, | ||
model_path: impl AsRef<Path> + Send + 'a, | ||
args: CompletionArgs, | ||
) -> BoxedFuture<Result<String, LLMEndpointError>> { | ||
let pinned = Box::pin(self.async_chat_completions(model_path, args)); | ||
Box::new(pinned) | ||
} | ||
|
||
fn stream_chat_completions<'a>( | ||
&'a self, | ||
model_path: impl AsRef<Path> + Send + 'a, | ||
args: CompletionArgs, | ||
) -> BoxedFuture<Result<Box<dyn Stream<Item = String> + Unpin + Send>, LLMEndpointError>> { | ||
let pinned = Box::pin(self.async_stream_chat_completions(model_path, args)); | ||
Box::new(pinned) | ||
} | ||
|
||
fn embeddings<'a>( | ||
&'a self, | ||
model_path: impl AsRef<Path> + Send + 'a, | ||
inputs: Vec<String>, | ||
) -> BoxedFuture<Result<Vec<Vec<f32>>, LLMEndpointError>> { | ||
let pinned = Box::pin(self.async_embeddings(model_path, inputs)); | ||
Box::new(pinned) | ||
} | ||
|
||
fn reset(&self) { | ||
self.models.clear(); | ||
} | ||
} | ||
|
||
impl Default for ChatFakerEndpoint { | ||
fn default() -> Self { | ||
let models: Arc<DashMap<String, ChatFakerModel>> = Default::default(); | ||
Self { models } | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
/* Copyright 2023- The Binedge, Lda team. All rights reserved. | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
//! Endpoint for the chat faker model RT | ||
|
||
use futures::Stream; | ||
use once_cell::sync::Lazy; | ||
|
||
use edgen_core::llm::{CompletionArgs, LLMEndpoint, LLMEndpointError}; | ||
use edgen_rt_chat_faker::ChatFakerEndpoint; | ||
|
||
use crate::model::Model; | ||
use crate::util::StoppingStream; | ||
|
||
static ENDPOINT: Lazy<ChatFakerEndpoint> = Lazy::new(Default::default); | ||
|
||
pub async fn chat_completion( | ||
model: Model, | ||
args: CompletionArgs, | ||
) -> Result<String, LLMEndpointError> { | ||
ENDPOINT | ||
.chat_completions( | ||
model | ||
.file_path() | ||
.map_err(move |e| LLMEndpointError::Load(e.to_string()))?, | ||
args, | ||
) | ||
.await | ||
} | ||
|
||
pub async fn chat_completion_stream( | ||
model: Model, | ||
args: CompletionArgs, | ||
) -> Result<StoppingStream<Box<dyn Stream<Item = String> + Unpin + Send>>, LLMEndpointError> { | ||
let stream = ENDPOINT | ||
.stream_chat_completions( | ||
model | ||
.file_path() | ||
.map_err(move |e| LLMEndpointError::Load(e.to_string()))?, | ||
args, | ||
) | ||
.await?; | ||
|
||
Ok(StoppingStream::wrap_with_stop_words( | ||
stream, | ||
vec![ | ||
"<|ASSISTANT|>".to_string(), | ||
"<|USER|>".to_string(), | ||
"<|TOOL|>".to_string(), | ||
"<|SYSTEM|>".to_string(), | ||
], | ||
)) | ||
} | ||
|
||
pub async fn embeddings( | ||
model: Model, | ||
input: Vec<String>, | ||
) -> Result<Vec<Vec<f32>>, LLMEndpointError> { | ||
ENDPOINT | ||
.embeddings( | ||
model | ||
.file_path() | ||
.map_err(move |e| LLMEndpointError::Load(e.to_string()))?, | ||
input, | ||
) | ||
.await | ||
} | ||
|
||
// Not needed. Just for completeness. | ||
#[allow(dead_code)] | ||
pub async fn reset_environment() { | ||
ENDPOINT.reset() | ||
} |
Oops, something went wrong.