Skip to content

Commit

Permalink
Merge pull request #130 from edgenai/feat/issue120
Browse files Browse the repository at this point in the history
Feat/issue120: backend-agnostic endpoints
  • Loading branch information
toschoo authored Mar 25, 2024
2 parents 0d8663e + c775bbc commit e9a3cc0
Show file tree
Hide file tree
Showing 15 changed files with 1,074 additions and 349 deletions.
15 changes: 15 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ members = [
"crates/edgen_async_compat",
"crates/edgen_rt_llama_cpp",
"crates/edgen_rt_whisper_cpp",
"crates/edgen_rt_chat_faker",
"edgen/src-tauri"
]

Expand Down
2 changes: 2 additions & 0 deletions crates/edgen_core/src/llm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ pub enum LLMEndpointError {
SessionCreationFailed(String),
#[error("failed to create embeddings: {0}")]
Embeddings(String), // Embeddings may involve session creation, advancing, and other things, so it should have its own error
#[error("unsuitable endpoint for model: {0}")]
UnsuitableEndpoint(String),
}

#[derive(Debug, Clone)]
Expand Down
16 changes: 16 additions & 0 deletions crates/edgen_rt_chat_faker/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
[package]
name = "edgen_rt_chat_faker"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
blake3 = { workspace = true }
dashmap = { workspace = true }
derive_more = { workspace = true }
edgen_core = { path = "../edgen_core" }
futures = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true, features = ["sync", "rt", "fs"] }
tracing = { workspace = true }
172 changes: 172 additions & 0 deletions crates/edgen_rt_chat_faker/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
/* Copyright 2023- The Binedge, Lda team. All rights reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

//! A fake model RT for chat completions that answers with predefined strings

use std::path::Path;
use std::sync::Arc;

use dashmap::DashMap;
use futures::Stream;
use tracing::info;

use edgen_core::llm::{CompletionArgs, LLMEndpoint, LLMEndpointError};
use edgen_core::BoxedFuture;

pub const CAPITAL: &str = "The capital of Canada is Ottawa.";
pub const CAPITAL_OF_PORTUGAL: &str = "The capital of Portugal is Lisbon.";
pub const DEFAULT_ANSWER: &str = "The answer is 42.";
pub const LONG_ANSWER: &str = "Call me Ishmael. Some years ago—never mind how long precisely—having little or no money in my purse, and nothing particular to interest me on shore, I thought I would sail about a little and see the watery part of the world. It is a way I have of driving off the spleen and regulating circulation. Whenever I find myself growing grim about the mouth; whenever it is a damp, drizzly November in my soul; whenever I find myself involuntarily pausing before coffin warehouses, and bringing up the rear of every funeral I meet; and especially whenever my hypos get such an upper hand of me, that it requires a strong moral principle to prevent me from deliberately stepping into the street, and methodically knocking people’s hats off—then, I account it high time to get to sea as soon as I can. There is nothing surprising in this. If they but knew it, almost all men in their degree, some time or other, cherish very nearly the same feelings towards the ocean with me.";

struct ChatFakerModel {}

impl ChatFakerModel {
async fn new(_path: impl AsRef<Path>) -> Self {
Self {}
}

async fn chat_completions(&self, args: CompletionArgs) -> Result<String, LLMEndpointError> {
info!("faking chat completions");
Ok(completions_for(&args.prompt))
}

async fn stream_chat_completions(
&self,
args: CompletionArgs,
) -> Result<Box<dyn Stream<Item = String> + Unpin + Send>, LLMEndpointError> {
info!("faking stream chat completions");
let msg = completions_for(&args.prompt);
let toks = streamify(&msg);
Ok(Box::new(futures::stream::iter(toks.into_iter())))
}

//TODO: implement
async fn embeddings(&self, _inputs: &Vec<String>) -> Result<Vec<Vec<f32>>, LLMEndpointError> {
info!("faking emeddings");
Ok(vec![])
}
}

fn completions_for(prompt: &str) -> String {
let prompt = prompt.to_lowercase();
if prompt.contains("capital") {
if prompt.contains("portugal") {
return CAPITAL_OF_PORTUGAL.to_string();
} else {
return CAPITAL.to_string();
}
} else if prompt.contains("long") {
return LONG_ANSWER.to_string();
} else {
return DEFAULT_ANSWER.to_string();
}
}

fn streamify(msg: &str) -> Vec<String> {
msg.split_whitespace().map(|s| s.to_string()).collect()
}

/// Faking a large language model endpoint, implementing [`LLMEndpoint`].
pub struct ChatFakerEndpoint {
/// A map of the models currently loaded into memory, with their path as the key.
models: Arc<DashMap<String, ChatFakerModel>>,
}

impl ChatFakerEndpoint {
// This is not strictly needed because we have no Unloading models.
// Anyway, it looks more like a real model.
async fn get(
&self,
model_path: impl AsRef<Path>,
) -> dashmap::mapref::one::Ref<String, ChatFakerModel> {
let key = model_path.as_ref().to_string_lossy().to_string();

if !self.models.contains_key(&key) {
let model = ChatFakerModel::new(model_path).await;
self.models.insert(key.clone(), model);
}

// PANIC SAFETY: Just inserted the element if it isn't already inside the map, so must be present in the map
self.models.get(&key).unwrap()
}

/// Helper `async` function that returns the full chat completions for the specified model and
/// [`CompletionArgs`].
async fn async_chat_completions(
&self,
model_path: impl AsRef<Path>,
args: CompletionArgs,
) -> Result<String, LLMEndpointError> {
let model = self.get(model_path).await;
model.chat_completions(args).await
}

/// Helper `async` function that returns the chat completions stream for the specified model and
/// [`CompletionArgs`].
async fn async_stream_chat_completions(
&self,
model_path: impl AsRef<Path>,
args: CompletionArgs,
) -> Result<Box<dyn Stream<Item = String> + Unpin + Send>, LLMEndpointError> {
let model = self.get(model_path).await;
model.stream_chat_completions(args).await
}

async fn async_embeddings(
&self,
model_path: impl AsRef<Path>,
inputs: Vec<String>,
) -> Result<Vec<Vec<f32>>, LLMEndpointError> {
let model = self.get(model_path).await;
model.embeddings(&inputs).await
}
}

impl LLMEndpoint for ChatFakerEndpoint {
fn chat_completions<'a>(
&'a self,
model_path: impl AsRef<Path> + Send + 'a,
args: CompletionArgs,
) -> BoxedFuture<Result<String, LLMEndpointError>> {
let pinned = Box::pin(self.async_chat_completions(model_path, args));
Box::new(pinned)
}

fn stream_chat_completions<'a>(
&'a self,
model_path: impl AsRef<Path> + Send + 'a,
args: CompletionArgs,
) -> BoxedFuture<Result<Box<dyn Stream<Item = String> + Unpin + Send>, LLMEndpointError>> {
let pinned = Box::pin(self.async_stream_chat_completions(model_path, args));
Box::new(pinned)
}

fn embeddings<'a>(
&'a self,
model_path: impl AsRef<Path> + Send + 'a,
inputs: Vec<String>,
) -> BoxedFuture<Result<Vec<Vec<f32>>, LLMEndpointError>> {
let pinned = Box::pin(self.async_embeddings(model_path, inputs));
Box::new(pinned)
}

fn reset(&self) {
self.models.clear();
}
}

impl Default for ChatFakerEndpoint {
fn default() -> Self {
let models: Arc<DashMap<String, ChatFakerModel>> = Default::default();
Self { models }
}
}
1 change: 1 addition & 0 deletions crates/edgen_server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ console-subscriber = { workspace = true }
dashmap = { workspace = true }
derive_more = { workspace = true }
edgen_core = { path = "../edgen_core" }
edgen_rt_chat_faker = { path = "../edgen_rt_chat_faker" }
edgen_rt_llama_cpp = { path = "../edgen_rt_llama_cpp" }
edgen_rt_whisper_cpp = { path = "../edgen_rt_whisper_cpp" }
either = { workspace = true, features = ["serde"] }
Expand Down
82 changes: 82 additions & 0 deletions crates/edgen_server/src/chat_faker.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/* Copyright 2023- The Binedge, Lda team. All rights reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

//! Endpoint for the chat faker model RT

use futures::Stream;
use once_cell::sync::Lazy;

use edgen_core::llm::{CompletionArgs, LLMEndpoint, LLMEndpointError};
use edgen_rt_chat_faker::ChatFakerEndpoint;

use crate::model::Model;
use crate::util::StoppingStream;

static ENDPOINT: Lazy<ChatFakerEndpoint> = Lazy::new(Default::default);

pub async fn chat_completion(
model: Model,
args: CompletionArgs,
) -> Result<String, LLMEndpointError> {
ENDPOINT
.chat_completions(
model
.file_path()
.map_err(move |e| LLMEndpointError::Load(e.to_string()))?,
args,
)
.await
}

pub async fn chat_completion_stream(
model: Model,
args: CompletionArgs,
) -> Result<StoppingStream<Box<dyn Stream<Item = String> + Unpin + Send>>, LLMEndpointError> {
let stream = ENDPOINT
.stream_chat_completions(
model
.file_path()
.map_err(move |e| LLMEndpointError::Load(e.to_string()))?,
args,
)
.await?;

Ok(StoppingStream::wrap_with_stop_words(
stream,
vec![
"<|ASSISTANT|>".to_string(),
"<|USER|>".to_string(),
"<|TOOL|>".to_string(),
"<|SYSTEM|>".to_string(),
],
))
}

pub async fn embeddings(
model: Model,
input: Vec<String>,
) -> Result<Vec<Vec<f32>>, LLMEndpointError> {
ENDPOINT
.embeddings(
model
.file_path()
.map_err(move |e| LLMEndpointError::Load(e.to_string()))?,
input,
)
.await
}

// Not needed. Just for completeness.
#[allow(dead_code)]
pub async fn reset_environment() {
ENDPOINT.reset()
}
Loading

0 comments on commit e9a3cc0

Please sign in to comment.