diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 7ac0cbc..d0ed4c6 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -54,7 +54,9 @@ jobs: - name: Add rust target (macOS) if: ${{ matrix.runner == 'macos-latest' }} - run: rustup target add aarch64-apple-darwin + run: | + rustup target add x86_64-apple-darwin + rustup target add aarch64-apple-darwin - name: Add rust target (Other) if: ${{ matrix.runner != 'macos-latest' }} diff --git a/Cargo.lock b/Cargo.lock index 0098f10..f9cfbc6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -591,7 +591,7 @@ dependencies = [ "bitflags 2.4.2", "cexpr", "clang-sys", - "itertools 0.12.1", + "itertools 0.11.0", "lazy_static", "lazycell", "log", @@ -1388,6 +1388,17 @@ dependencies = [ "parking_lot_core", ] +[[package]] +name = "dbus" +version = "0.9.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bb21987b9fb1613058ba3843121dd18b163b254d8a6e797e144cbac14d96d1b" +dependencies = [ + "libc", + "libdbus-sys", + "winapi", +] + [[package]] name = "deranged" version = "0.3.11" @@ -1561,6 +1572,7 @@ dependencies = [ "tauri-build", "tokio", "tracing", + "winapi", ] [[package]] @@ -1587,6 +1599,7 @@ dependencies = [ "derive_more", "directories", "edgen_async_compat", + "either", "futures", "notify", "num_cpus", @@ -1609,6 +1622,7 @@ dependencies = [ name = "edgen_rt_chat_faker" version = "0.1.0" dependencies = [ + "async-trait", "blake3", "dashmap", "derive_more", @@ -1625,6 +1639,7 @@ version = "0.1.0" dependencies = [ "async-trait", "candle-core", + "candle-nn", "candle-transformers", "edgen_core", "image 0.25.1", @@ -1639,6 +1654,7 @@ dependencies = [ name = "edgen_rt_llama_cpp" version = "0.1.0" dependencies = [ + "async-trait", "blake3", "dashmap", "derive_more", @@ -1654,6 +1670,7 @@ dependencies = [ name = "edgen_rt_whisper_cpp" version = "0.1.0" dependencies = [ + "async-trait", "dashmap", "derive_more", "edgen_core", @@ -1691,7 +1708,9 @@ dependencies = [ "levenshtein", "once_cell", "pin-project", + "rand 0.8.5", "reqwest 0.12.3", + "reqwest-eventsource", "rubato", "serde", "serde_derive", @@ -1845,6 +1864,17 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "eventsource-stream" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab" +dependencies = [ + "futures-core", + "nom", + "pin-project-lite", +] + [[package]] name = "exr" version = "1.72.0" @@ -2093,6 +2123,12 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" +[[package]] +name = "futures-timer" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" + [[package]] name = "futures-util" version = "0.3.30" @@ -2832,7 +2868,7 @@ dependencies = [ "httpdate", "itoa 1.0.10", "pin-project-lite", - "socket2 0.5.5", + "socket2 0.4.10", "tokio", "tower-service", "tracing", @@ -3315,6 +3351,16 @@ version = "0.2.153" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" +[[package]] +name = "libdbus-sys" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06085512b750d640299b79be4bad3d2fa90a9c00b1fd9e1b46364f66f0485c72" +dependencies = [ + "cc", + "pkg-config", 
+] + [[package]] name = "libfuzzer-sys" version = "0.4.7" @@ -3389,8 +3435,8 @@ checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" [[package]] name = "llama_cpp" -version = "0.3.1" -source = "git+https://github.com/edgenai/llama_cpp-rs?branch=main#9dd2b2229205096645e76a1712c6b73a7a781dd4" +version = "0.3.2" +source = "git+https://github.com/edgenai/llama_cpp-rs?branch=main#55eb9691b13d07f56eaa68992d5e6accaa281691" dependencies = [ "derive_more", "futures", @@ -3403,8 +3449,8 @@ dependencies = [ [[package]] name = "llama_cpp_sys" -version = "0.3.1" -source = "git+https://github.com/edgenai/llama_cpp-rs?branch=main#9dd2b2229205096645e76a1712c6b73a7a781dd4" +version = "0.3.2" +source = "git+https://github.com/edgenai/llama_cpp-rs?branch=main#55eb9691b13d07f56eaa68992d5e6accaa281691" dependencies = [ "ash", "bindgen", @@ -3980,13 +4026,14 @@ dependencies = [ [[package]] name = "opener" -version = "0.6.1" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c62dcb6174f9cb326eac248f07e955d5d559c272730b6c03e396b443b562788" +checksum = "f9901cb49d7fc923b256db329ee26ffed69130bf05d74b9efdd1875c92d6af01" dependencies = [ "bstr", + "dbus", "normpath", - "winapi", + "windows-sys 0.52.0", ] [[package]] @@ -4882,14 +4929,32 @@ dependencies = [ "serde_urlencoded", "sync_wrapper", "tokio", + "tokio-util", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", + "wasm-streams", "web-sys", "winreg 0.52.0", ] +[[package]] +name = "reqwest-eventsource" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "632c55746dbb44275691640e7b40c907c16a2dc1a5842aa98aaec90da6ec6bde" +dependencies = [ + "eventsource-stream", + "futures-core", + "futures-timer", + "mime", + "nom", + "pin-project-lite", + "reqwest 0.12.3", + "thiserror", +] + [[package]] name = "reserve-port" version = "2.0.1" diff --git a/crates/edgen_core/Cargo.toml b/crates/edgen_core/Cargo.toml index a011dde..ba145f3 100644 --- a/crates/edgen_core/Cargo.toml +++ b/crates/edgen_core/Cargo.toml @@ -9,6 +9,7 @@ async-trait = { workspace = true } dashmap = { workspace = true } directories = { workspace = true } derive_more = { workspace = true } +either = { workspace = true } edgen_async_compat = { path = "../edgen_async_compat", features = ["runtime-tokio"] } notify = { workspace = true } num_cpus = { workspace = true } diff --git a/crates/edgen_core/src/lib.rs b/crates/edgen_core/src/lib.rs index e6cd4df..d3f06f5 100644 --- a/crates/edgen_core/src/lib.rs +++ b/crates/edgen_core/src/lib.rs @@ -18,7 +18,6 @@ extern crate alloc; -use std::future::Future; use std::time::Duration; pub mod llm; @@ -29,9 +28,6 @@ pub mod settings; pub mod image_generation; pub mod perishable; -/// A generic [`Box`]ed [`Future`], used to emulate `async` functions in traits. -pub type BoxedFuture<'a, T> = Box + Send + Unpin + 'a>; - /// Return the [`Duration`] that cleanup threads should wait before looking for and freeing unused /// resources, after last doing so. pub fn cleanup_interval() -> Duration { diff --git a/crates/edgen_core/src/llm.rs b/crates/edgen_core/src/llm.rs index d6692f3..18ae019 100644 --- a/crates/edgen_core/src/llm.rs +++ b/crates/edgen_core/src/llm.rs @@ -10,15 +10,17 @@ * limitations under the License. 
 */

+use core::fmt::{Display, Formatter};
 use core::time::Duration;
+use std::collections::HashMap;
 use std::path::Path;

+use derive_more::{Deref, DerefMut, From};
+use either::Either;
 use futures::Stream;
 use serde::Serialize;
 use thiserror::Error;

-use crate::BoxedFuture;
-
 /// The context tag marking the start of generated dialogue.
 pub const ASSISTANT_TAG: &str = "<|ASSISTANT|>";
@@ -45,52 +47,305 @@ pub enum LLMEndpointError {
     UnsuitableEndpoint(String),
 }

-#[derive(Debug, Clone)]
-pub struct CompletionArgs {
-    pub prompt: String,
-    pub one_shot: bool,
-    pub seed: Option<u32>,
-    pub frequency_penalty: f32,
-    pub context_hint: Option<u32>,
+/// The plaintext or image content of a [`ChatMessage`] within a [`CreateChatCompletionRequest`].
+///
+/// This can be plain text or a URL to an image.
+#[derive(Debug)]
+pub enum ContentPart {
+    /// Plain text.
+    Text {
+        /// The plain text.
+        text: String,
+    },
+    /// A URL to an image.
+    ImageUrl {
+        /// The URL.
+        url: String,
+
+        /// A description of the image behind the URL, if any.
+        detail: Option<String>,
+    },
 }

-impl Default for CompletionArgs {
-    fn default() -> Self {
-        Self {
-            prompt: "".to_string(),
-            one_shot: false,
-            seed: None,
-            frequency_penalty: 0.0,
-            context_hint: None,
+impl Display for ContentPart {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        match self {
+            ContentPart::Text { text } => write!(f, "{}", text),
+            ContentPart::ImageUrl { url, detail } => {
+                if let Some(detail) = detail {
+                    write!(f, "<{}> ({})", url, detail)
+                } else {
+                    write!(f, "<{}>", url)
+                }
+            }
         }
     }
 }

+/// A description of a function provided to a large language model, to assist it in interacting
+/// with the outside world.
+///
+/// This is included in [`AssistantToolCall`]s within [`ChatMessage`]s.
+#[derive(Debug)]
+pub struct AssistantFunctionStub {
+    /// The name of the function from the assistant's point of view.
+    pub name: String,
+
+    /// The arguments passed into the function.
+    pub arguments: String,
+}
+
+/// A description of a function that an assistant called.
+///
+/// This is included in [`ChatMessage`]s when the `tool_calls` field is present.
+#[derive(Debug)]
+pub struct AssistantToolCall {
+    /// A unique identifier for the invocation of this function.
+    pub id: String,
+
+    /// The type of the invoked tool.
+    ///
+    /// OpenAI currently specifies this to always be `function`, but more variants may be added
+    /// in the future.
+    pub type_: String,
+
+    /// The invoked function.
+    pub function: AssistantFunctionStub,
+}
+
+/// A chat message in a multi-user dialogue.
+///
+/// This is used as context for a [`CreateChatCompletionRequest`].
+#[derive(Debug)]
+pub enum ChatMessage {
+    /// A message from the system. This is typically used to set the initial system prompt; for
+    /// example, "you are a helpful assistant".
+    System {
+        /// The content of the message, if any.
+        content: Option<String>,
+
+        /// If present, a name for the system.
+        name: Option<String>,
+    },
+    /// A message from a user.
+    User {
+        /// The content of the message. This can be a sequence of multiple plain text or image
+        /// parts.
+        content: Either<String, Vec<ContentPart>>,
+
+        /// If present, a name for the user.
+        name: Option<String>,
+    },
+    /// A message from an assistant.
+    Assistant {
+        /// The plaintext message of the message, if any.
+        content: Option<String>,
+
+        /// The name of the assistant, if any.
+        name: Option<String>,
+
+        /// If the assistant used any tools in generating this message, the tools that the assistant
+        /// used.
+        tool_calls: Option<Vec<AssistantToolCall>>,
+    },
+    /// A message from a tool accessible by other peers in the dialogue.
+ Tool { + /// The plaintext that the tool generated, if any. + content: Option, + + /// A unique identifier for the specific invocation that generated this message. + tool_call_id: String, + }, +} + +/* + +/// A tool made available to an assistant that invokes a named function. +/// +/// This is included in [`ToolStub`]s within [`CreateChatCompletionRequest`]s. +#[derive(Debug)] +pub struct FunctionStub<'a> { + /// A human-readable description of what the tool does. + pub description: Option>, + + /// The name of the tool. + pub name: Cow<'a, str>, + + /// A [JSON schema][json-schema] describing the parameters that the tool accepts. + /// + /// [json-schema]: https://json-schema.org/ + pub parameters: serde_json::Value, +} + +/// A tool made available to an assistant. +/// +/// At present, this can only be a [`FunctionStub`], but this enum is marked `#[non_exhaustive]` +/// for the (likely) event that more variants are added in the future. +/// +/// This is included in [`CreateChatCompletionRequest`]s. +#[derive(Debug)] +#[non_exhaustive] +pub enum ToolStub<'a> { + /// A named function that can be invoked by an assistant. + Function { + /// The named function. + function: FunctionStub<'a>, + }, +} + +*/ + +/// A sequence of chat messages in a [`CreateChatCompletionRequest`]. +/// +/// This implements [`Display`] to generate a transcript of the chat messages compatible with most +/// LLaMa-based models. +#[derive(Debug, Default, Deref, DerefMut, From)] +pub struct ChatMessages( + #[deref] + #[deref_mut] + pub Vec, +); + +impl Display for ChatMessages { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + for message in &self.0 { + match message { + ChatMessage::System { + content: Some(data), + .. + } => { + write!(f, "{SYSTEM_TAG}{data}")?; + } + ChatMessage::User { + content: Either::Left(data), + .. + } => { + write!(f, "{USER_TAG}{data}")?; + } + ChatMessage::User { + content: Either::Right(data), + .. + } => { + write!(f, "{USER_TAG}")?; + + for part in data { + write!(f, "{part}")?; + } + } + ChatMessage::Assistant { + content: Some(data), + .. + } => { + write!(f, "{ASSISTANT_TAG}{data}")?; + } + ChatMessage::Tool { + content: Some(data), + .. + } => { + write!(f, "{TOOL_TAG}{data}")?; + } + _ => {} + } + } + + Ok(()) + } +} + +/// A request to generate chat completions for the provided context. +#[derive(Debug)] +pub struct CompletionArgs { + /// The messages that have been sent in the dialogue so far. + pub messages: ChatMessages, + + /// A number in `[-2.0, 2.0]`. A higher number decreases the likelihood that the model + /// repeats itself. + pub frequency_penalty: Option, + + /// A map of token IDs to `[-100.0, +100.0]`. Adds a percentage bias to those tokens before + /// sampling; a value of `-100.0` prevents the token from being selected at all. + /// + /// You could use this to, for example, prevent the model from emitting profanity. + pub logit_bias: Option>, + + /// The maximum number of tokens to generate. If `None`, terminates at the first stop token + /// or the end of sentence. + pub max_tokens: Option, + + /// How many choices to generate for each token in the output. `1` by default. You can use + /// this to generate several sets of completions for the same prompt. + pub n: Option, + + /// A number in `[-2.0, 2.0]`. Positive values "increase the model's likelihood to talk about + /// new topics." + pub presence_penalty: Option, + + /// An RNG seed for the session. Random by default. + pub seed: Option, + + /// A stop phrase or set of stop phrases. 
+    ///
+    /// The server will pause emitting completions if it appears to be generating a stop phrase,
+    /// and will terminate completions if a full stop phrase is detected.
+    ///
+    /// Stop phrases are never emitted to the client.
+    pub stop: Option<Either<String, Vec<String>>>,
+
+    /// The sampling temperature, in `[0.0, 2.0]`. Higher values make the output more random.
+    pub temperature: Option<f32>,
+
+    /// Nucleus sampling. If you set this value to 10%, only the top 10% of tokens are used for
+    /// sampling, preventing sampling of very low-probability tokens.
+    pub top_p: Option<f32>,
+
+    /// A list of tools made available to the model.
+    // pub tools: Option<Vec<ToolStub<'a>>>,
+
+    /// If present, the tool that the user has chosen to use.
+    ///
+    /// OpenAI states:
+    ///
+    /// - `none` prevents any tool from being used,
+    /// - `auto` allows any tool to be used, or
+    /// - you can provide a description of the tool entirely instead of a name.
+    // pub tool_choice: Option<Either<Cow<'a, str>, ToolStub<'a>>>,
+
+    /// Indicates whether this is an isolated request, with no associated past or future context. This may allow for
+    /// optimisations in some implementations. Default: `false`
+    pub one_shot: Option<bool>,
+
+    /// A hint for how big a context will be.
+    ///
+    /// # Warning
+    /// An unsound hint may severely drop performance and/or inference quality, and in some cases even cause Edgen
+    /// to crash. Do not set this value unless you know what you are doing.
+    pub context_hint: Option<u32>,
+}
+
 /// A large language model endpoint, that is, an object that provides various ways to interact with
 /// a large language model.
+#[async_trait::async_trait]
 pub trait LLMEndpoint {
-    /// Given a prompt with several arguments, return a [`Box`]ed [`Future`] which may eventually
-    /// contain the prompt completion in [`String`] form.
-    fn chat_completions<'a>(
-        &'a self,
-        model_path: impl AsRef<Path> + Send + 'a,
+    /// Given a prompt with several arguments, return a prompt completion in [`String`] form.
+    async fn chat_completions(
+        &self,
+        model_path: impl AsRef<Path> + Send,
         args: CompletionArgs,
-    ) -> BoxedFuture<'a, Result<String, LLMEndpointError>>;
-
-    /// Given a prompt with several arguments, return a [`Box`]ed [`Future`] which may eventually
-    /// contain a [`Stream`] of [`String`] chunks of the prompt completion, acquired as they get
-    /// processed.
-    fn stream_chat_completions<'a>(
-        &'a self,
-        model_path: impl AsRef<Path> + Send + 'a,
+    ) -> Result<String, LLMEndpointError>;
+
+    /// Given a prompt with several arguments, return a [`Stream`] of [`String`] chunks of the
+    /// prompt completion, acquired as they get processed.
+    async fn stream_chat_completions(
+        &self,
+        model_path: impl AsRef<Path> + Send,
         args: CompletionArgs,
-    ) -> BoxedFuture<'a, Result<Box<dyn Stream<Item = String> + Unpin + Send>, LLMEndpointError>>;
+    ) -> Result<Box<dyn Stream<Item = String> + Unpin + Send>, LLMEndpointError>;

-    fn embeddings<'a>(
-        &'a self,
-        model_path: impl AsRef<Path> + Send + 'a,
+    async fn embeddings(
+        &self,
+        model_path: impl AsRef<Path> + Send,
         inputs: Vec<String>,
-    ) -> BoxedFuture<'a, Result<Vec<Vec<f32>>, LLMEndpointError>>;
+    ) -> Result<Vec<Vec<f32>>, LLMEndpointError>;

     /// Unloads everything from memory.
     fn reset(&self);
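For reference, a rough sketch of how a caller might drive the new async `LLMEndpoint` trait; this snippet is not part of the patch, and the endpoint value, model path, and prompts are placeholders:

    use either::Either;
    use edgen_core::llm::{ChatMessage, ChatMessages, CompletionArgs, LLMEndpoint, LLMEndpointError};

    async fn ask(endpoint: &impl LLMEndpoint) -> Result<String, LLMEndpointError> {
        // `ChatMessages` derefs to `Vec<ChatMessage>`, so messages can be pushed directly.
        let mut messages = ChatMessages::default();
        messages.push(ChatMessage::System {
            content: Some("You are Edgen, a helpful assistant.".to_string()),
            name: None,
        });
        messages.push(ChatMessage::User {
            content: Either::Left("What is the capital of Portugal?".to_string()),
            name: None,
        });
        let args = CompletionArgs {
            messages,
            max_tokens: Some(256),
            one_shot: Some(true),
            // The remaining sampling knobs fall back to the runtime's defaults.
            frequency_penalty: None,
            logit_bias: None,
            n: None,
            presence_penalty: None,
            seed: None,
            stop: None,
            temperature: None,
            top_p: None,
            context_hint: None,
        };
        endpoint.chat_completions("models/model.gguf", args).await
    }

Because the methods are now plain `async fn`s via `async_trait`, callers simply `.await` them instead of handling a `BoxedFuture` by hand.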
diff --git a/crates/edgen_core/src/whisper.rs b/crates/edgen_core/src/whisper.rs
index dc4bddc..e589768 100644
--- a/crates/edgen_core/src/whisper.rs
+++ b/crates/edgen_core/src/whisper.rs
@@ -18,8 +18,6 @@ use thiserror::Error;
 use utoipa::ToSchema;
 use uuid::Uuid;

-use crate::BoxedFuture;
-
 #[derive(Serialize, Error, Debug)]
 pub enum WhisperEndpointError {
     #[error("failed to advance context: {0}")]
@@ -45,14 +43,14 @@ pub struct TranscriptionArgs {
     pub session: Option<Uuid>,
 }

+#[async_trait::async_trait]
 pub trait WhisperEndpoint {
-    /// Given an audio segment with several arguments, return a [`Box`]ed [`Future`] which may
-    /// eventually contain its transcription in [`String`] form.
-    fn transcription<'a>(
-        &'a self,
-        model_path: impl AsRef<Path> + Send + 'a,
+    /// Given an audio segment with several arguments, return a transcription in [`String`] form.
+    async fn transcription(
+        &self,
+        model_path: impl AsRef<Path> + Send,
         args: TranscriptionArgs,
-    ) -> BoxedFuture<'a, Result<(String, Option<Uuid>), WhisperEndpointError>>;
+    ) -> Result<(String, Option<Uuid>), WhisperEndpointError>;

     /// Unloads everything from memory.
     fn reset(&self);
diff --git a/crates/edgen_rt_chat_faker/Cargo.toml b/crates/edgen_rt_chat_faker/Cargo.toml
index ec05cb8..41d7f3e 100644
--- a/crates/edgen_rt_chat_faker/Cargo.toml
+++ b/crates/edgen_rt_chat_faker/Cargo.toml
@@ -6,6 +6,7 @@ edition = "2021"
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

 [dependencies]
+async-trait = { workspace = true }
 blake3 = { workspace = true }
 dashmap = { workspace = true }
 derive_more = { workspace = true }
diff --git a/crates/edgen_rt_chat_faker/src/lib.rs b/crates/edgen_rt_chat_faker/src/lib.rs
index b06a3e3..6a9cae5 100644
--- a/crates/edgen_rt_chat_faker/src/lib.rs
+++ b/crates/edgen_rt_chat_faker/src/lib.rs
@@ -20,7 +20,6 @@ use futures::Stream;
 use tracing::info;

 use edgen_core::llm::{CompletionArgs, LLMEndpoint, LLMEndpointError};
-use edgen_core::BoxedFuture;

 pub const CAPITAL: &str = "The capital of Canada is Ottawa.";
 pub const CAPITAL_OF_PORTUGAL: &str = "The capital of Portugal is Lisbon.";
@@ -34,23 +33,25 @@ impl ChatFakerModel {
         Self {}
     }

-    async fn chat_completions(&self, args: CompletionArgs) -> Result<String, LLMEndpointError> {
+    async fn chat_completions(&self, args: &CompletionArgs) -> Result<String, LLMEndpointError> {
         info!("faking chat completions");
-        Ok(completions_for(&args.prompt))
+        let prompt = format!("{}<|ASSISTANT|>", args.messages);
+        Ok(completions_for(&prompt))
     }

     async fn stream_chat_completions(
         &self,
-        args: CompletionArgs,
+        args: &CompletionArgs,
     ) -> Result<Box<dyn Stream<Item = String> + Unpin + Send>, LLMEndpointError> {
         info!("faking stream chat completions");
-        let msg = completions_for(&args.prompt);
+        let prompt = format!("{}<|ASSISTANT|>", args.messages);
+        let msg = completions_for(&prompt);
         let toks = streamify(&msg);
         Ok(Box::new(futures::stream::iter(toks.into_iter())))
     }

     //TODO: implement
-    async fn embeddings(&self, _inputs: &Vec<String>) -> Result<Vec<Vec<f32>>, LLMEndpointError> {
+    async fn embeddings(&self, _inputs: &[String]) -> Result<Vec<Vec<f32>>, LLMEndpointError> {
         info!("faking embeddings");
         Ok(vec![])
     }
@@ -98,66 +99,36 @@ impl ChatFakerEndpoint {
         // PANIC SAFETY: Just inserted the element if it isn't already inside the map, so must be present in the map
         self.models.get(&key).unwrap()
     }
+}

-    /// Helper `async` function that returns the full chat completions for the specified model and
-    /// [`CompletionArgs`].
- async fn async_chat_completions( +#[async_trait::async_trait] +impl LLMEndpoint for ChatFakerEndpoint { + async fn chat_completions( &self, - model_path: impl AsRef, + model_path: impl AsRef + Send, args: CompletionArgs, ) -> Result { let model = self.get(model_path).await; - model.chat_completions(args).await + model.chat_completions(&args).await } - /// Helper `async` function that returns the chat completions stream for the specified model and - /// [`CompletionArgs`]. - async fn async_stream_chat_completions( + async fn stream_chat_completions( &self, - model_path: impl AsRef, + model_path: impl AsRef + Send, args: CompletionArgs, ) -> Result + Unpin + Send>, LLMEndpointError> { let model = self.get(model_path).await; - model.stream_chat_completions(args).await + model.stream_chat_completions(&args).await } - async fn async_embeddings( + async fn embeddings( &self, - model_path: impl AsRef, + model_path: impl AsRef + Send, inputs: Vec, ) -> Result>, LLMEndpointError> { let model = self.get(model_path).await; model.embeddings(&inputs).await } -} - -impl LLMEndpoint for ChatFakerEndpoint { - fn chat_completions<'a>( - &'a self, - model_path: impl AsRef + Send + 'a, - args: CompletionArgs, - ) -> BoxedFuture> { - let pinned = Box::pin(self.async_chat_completions(model_path, args)); - Box::new(pinned) - } - - fn stream_chat_completions<'a>( - &'a self, - model_path: impl AsRef + Send + 'a, - args: CompletionArgs, - ) -> BoxedFuture + Unpin + Send>, LLMEndpointError>> { - let pinned = Box::pin(self.async_stream_chat_completions(model_path, args)); - Box::new(pinned) - } - - fn embeddings<'a>( - &'a self, - model_path: impl AsRef + Send + 'a, - inputs: Vec, - ) -> BoxedFuture>, LLMEndpointError>> { - let pinned = Box::pin(self.async_embeddings(model_path, inputs)); - Box::new(pinned) - } fn reset(&self) { self.models.clear(); diff --git a/crates/edgen_rt_image_generation_candle/Cargo.toml b/crates/edgen_rt_image_generation_candle/Cargo.toml index 9d8f557..ff3ce20 100644 --- a/crates/edgen_rt_image_generation_candle/Cargo.toml +++ b/crates/edgen_rt_image_generation_candle/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" [dependencies] async-trait = { workspace = true } candle-core = "0.4.1" +candle-nn = "0.4.1" candle-transformers = "0.4.1" edgen_core = { path = "../edgen_core" } image = "0.25.1" diff --git a/crates/edgen_rt_image_generation_candle/src/lib.rs b/crates/edgen_rt_image_generation_candle/src/lib.rs index c238b7e..7d25f3b 100644 --- a/crates/edgen_rt_image_generation_candle/src/lib.rs +++ b/crates/edgen_rt_image_generation_candle/src/lib.rs @@ -4,18 +4,18 @@ use std::path::Path; use candle_core::backend::BackendDevice; use candle_core::{CudaDevice, DType, Device, IndexOp, Module, Tensor, D}; -use candle_transformers::models::stable_diffusion; use candle_transformers::models::stable_diffusion::vae::AutoEncoderKL; +use candle_transformers::models::{stable_diffusion, wuerstchen}; use image::{ImageBuffer, ImageError, ImageFormat, Rgb}; +use rand::random; use thiserror::Error; use tokenizers::Tokenizer; +use tracing::{debug, info, info_span, warn}; use edgen_core::image_generation::{ ImageGenerationArgs, ImageGenerationEndpoint, ImageGenerationEndpointError, ModelFiles, }; use edgen_core::settings::{DevicePolicy, SETTINGS}; -use rand::random; -use tracing::{debug, info, info_span, warn}; #[derive(Error, Debug)] enum CandleError { @@ -37,7 +37,7 @@ enum CandleError { EncodeWriteFailed(#[from] IntoInnerError>>>), } -fn text_embeddings( +fn sd_text_embeddings( prompt: &str, uncond_prompt: 
&str, tokenizer: impl AsRef, @@ -106,7 +106,7 @@ fn text_embeddings( Ok(text_embeddings) } -fn to_bitmap( +fn sd_to_bitmap( vae: &AutoEncoderKL, latents: &Tensor, vae_scale: f64, @@ -137,22 +137,14 @@ fn to_bitmap( Ok(res) } -fn generate_image( +fn sd_generate_image( model: ModelFiles, args: ImageGenerationArgs, - device_policy: &DevicePolicy, + device: Device, ) -> Result>, CandleError> { - let _span = info_span!("gen_image", images = args.images, steps = args.steps).entered(); + let _span = info_span!("sd_gen_image", images = args.images, steps = args.steps).entered(); let config = stable_diffusion::StableDiffusionConfig::v2_1(None, args.height, args.width); let scheduler = config.build_scheduler(args.steps)?; - let device = match device_policy { - DevicePolicy::AlwaysCpu { .. } => Device::Cpu, - DevicePolicy::AlwaysDevice { .. } => Device::Cuda(CudaDevice::new(0)?), - _ => { - warn!("Unknown device policy, executing on CPU"); - Device::Cpu - } - }; let use_guide_scale = args.guidance_scale > 1.0; let dtype = DType::F16; let bsize = 1; @@ -172,7 +164,7 @@ fn generate_image( } else { model.clip2_weights.as_ref().unwrap() }; - text_embeddings( + sd_text_embeddings( &args.prompt, &args.uncond_prompt, &model.tokenizer, @@ -237,12 +229,217 @@ fn generate_image( latents = scheduler.step(&noise_pred, timestep, &latents)?; } - images.extend(to_bitmap(&vae, &latents, args.vae_scale, bsize)?) + images.extend(sd_to_bitmap(&vae, &latents, args.vae_scale, bsize)?) } Ok(images) } +const PRIOR_GUIDANCE_SCALE: f64 = 4.0; +const RESOLUTION_MULTIPLE: f64 = 42.67; +const LATENT_DIM_SCALE: f64 = 10.67; +const PRIOR_CIN: usize = 16; +const DECODER_CIN: usize = 4; + +fn ws_encode_prompt( + prompt: &str, + uncond_prompt: Option<&str>, + tokenizer: impl AsRef, + clip_weights: impl AsRef, + clip_config: stable_diffusion::clip::Config, + device: &Device, +) -> Result { + let tokenizer = + Tokenizer::from_file(tokenizer).map_err(|e| CandleError::Tokenizer(e.to_string()))?; + let pad_id = match &clip_config.pad_with { + Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(), + None => *tokenizer.get_vocab(true).get("<|endoftext|>").unwrap(), + }; + let mut tokens = tokenizer + .encode(prompt, true) + .map_err(|e| CandleError::Tokenizer(e.to_string()))? + .get_ids() + .to_vec(); + let tokens_len = tokens.len(); + while tokens.len() < clip_config.max_position_embeddings { + tokens.push(pad_id) + } + let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?; + + let text_model = + stable_diffusion::build_clip_transformer(&clip_config, clip_weights, device, DType::F32)?; + let text_embeddings = text_model.forward_with_mask(&tokens, tokens_len - 1)?; + match uncond_prompt { + None => Ok(text_embeddings), + Some(uncond_prompt) => { + let mut uncond_tokens = tokenizer + .encode(uncond_prompt, true) + .map_err(|e| CandleError::Tokenizer(e.to_string()))? 
+ .get_ids() + .to_vec(); + let uncond_tokens_len = uncond_tokens.len(); + while uncond_tokens.len() < clip_config.max_position_embeddings { + uncond_tokens.push(pad_id) + } + let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?; + + let uncond_embeddings = + text_model.forward_with_mask(&uncond_tokens, uncond_tokens_len - 1)?; + let text_embeddings = Tensor::cat(&[text_embeddings, uncond_embeddings], 0)?; + Ok(text_embeddings) + } + } +} + +struct Wuerstchen { + tokenizer: std::path::PathBuf, + prior_tokenizer: std::path::PathBuf, + clip: std::path::PathBuf, + prior_clip: std::path::PathBuf, + decoder: std::path::PathBuf, + vq_gan: std::path::PathBuf, + prior: std::path::PathBuf, +} + +#[allow(dead_code)] +fn ws_generate_image( + paths: Wuerstchen, + args: ImageGenerationArgs, + device: Device, +) -> Result>, CandleError> { + let _span = info_span!("ws_gen_image", images = args.images, steps = args.steps).entered(); + let height = args.height.unwrap_or(1024); + let width = args.width.unwrap_or(1024); + + let prior_text_embeddings = ws_encode_prompt( + &args.prompt, + Some(&args.uncond_prompt), + &paths.prior_tokenizer, + &paths.prior_clip, + stable_diffusion::clip::Config::wuerstchen_prior(), + &device, + )?; + + let text_embeddings = ws_encode_prompt( + &args.prompt, + None, + &paths.tokenizer, + &paths.clip, + stable_diffusion::clip::Config::wuerstchen(), + &device, + )?; + + let b_size = 1; + let image_embeddings = { + // https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/prior/config.json + let latent_height = (height as f64 / RESOLUTION_MULTIPLE).ceil() as usize; + let latent_width = (width as f64 / RESOLUTION_MULTIPLE).ceil() as usize; + let mut latents = Tensor::randn( + 0f32, + 1f32, + (b_size, PRIOR_CIN, latent_height, latent_width), + &device, + )?; + + let prior = { + let vb = unsafe { + candle_nn::VarBuilder::from_mmaped_safetensors( + &[&paths.prior], + DType::F32, + &device, + )? + }; + wuerstchen::prior::WPrior::new(PRIOR_CIN, 1536, 1280, 64, 32, 24, false, vb)? + }; + let prior_scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?; + let timesteps = prior_scheduler.timesteps(); + let timesteps = ×teps[..timesteps.len() - 1]; + for (index, &t) in timesteps.iter().enumerate() { + debug!("Prior de-noising step {index}"); + let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?; + let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?; + let noise_pred = prior.forward(&latent_model_input, &ratio, &prior_text_embeddings)?; + let noise_pred = noise_pred.chunk(2, 0)?; + let (noise_pred_text, noise_pred_uncond) = (&noise_pred[0], &noise_pred[1]); + let noise_pred = (noise_pred_uncond + + ((noise_pred_text - noise_pred_uncond)? * PRIOR_GUIDANCE_SCALE)?)?; + latents = prior_scheduler.step(&noise_pred, t, &latents)?; + } + ((latents * 42.)? - 1.)? + }; + + let vqgan = { + let vb = unsafe { + candle_nn::VarBuilder::from_mmaped_safetensors(&[&paths.vq_gan], DType::F32, &device)? + }; + wuerstchen::paella_vq::PaellaVQ::new(vb)? + }; + + // https://huggingface.co/warp-ai/wuerstchen/blob/main/decoder/config.json + let decoder = { + let vb = unsafe { + candle_nn::VarBuilder::from_mmaped_safetensors(&[&paths.decoder], DType::F32, &device)? + }; + wuerstchen::diffnext::WDiffNeXt::new( + DECODER_CIN, + DECODER_CIN, + 64, + 1024, + 1024, + 2, + false, + vb, + )? 
+ }; + + let mut res = vec![]; + res.reserve(args.images as usize); + for idx in 0..args.images { + let _span = info_span!("image", image_index = idx).entered(); + info!("Generating image"); + // https://huggingface.co/warp-ai/wuerstchen/blob/main/model_index.json + let latent_height = (image_embeddings.dim(2)? as f64 * LATENT_DIM_SCALE) as usize; + let latent_width = (image_embeddings.dim(3)? as f64 * LATENT_DIM_SCALE) as usize; + + let mut latents = Tensor::randn( + 0f32, + 1f32, + (b_size, DECODER_CIN, latent_height, latent_width), + &device, + )?; + + let scheduler = wuerstchen::ddpm::DDPMWScheduler::new(12, Default::default())?; + let timesteps = scheduler.timesteps(); + let timesteps = ×teps[..timesteps.len() - 1]; + for (index, &t) in timesteps.iter().enumerate() { + debug!("Image generation step {index}"); + let ratio = (Tensor::ones(1, DType::F32, &device)? * t)?; + let noise_pred = + decoder.forward(&latents, &ratio, &image_embeddings, Some(&text_embeddings))?; + latents = scheduler.step(&noise_pred, t, &latents)?; + } + let image = vqgan.decode(&(&latents * 0.3764)?)?; + let image = (image.clamp(0f32, 1f32)? * 255.)? + .to_dtype(DType::U8)? + .i(0)?; + let (channel, height, width) = image.dims3()?; + if channel != 3 { + return Err(CandleError::BadDims { + dims: channel, + expected: 3, + }); + } + let img = image.permute((1, 2, 0))?.flatten_all()?; + let pixels = img.to_vec1::()?; + let buf = ImageBuffer::, _>::from_vec(width as u32, height as u32, pixels) + .ok_or(CandleError::BadOutput)?; + let mut encoded = BufWriter::new(Cursor::new(Vec::new())); + buf.write_to(&mut encoded, ImageFormat::Png)?; + res.push(encoded.into_inner()?.into_inner()); + } + Ok(res) +} + pub struct CandleImageGenerationEndpoint {} #[async_trait::async_trait] @@ -252,8 +449,18 @@ impl ImageGenerationEndpoint for CandleImageGenerationEndpoint { model: ModelFiles, args: ImageGenerationArgs, ) -> Result>, ImageGenerationEndpointError> { - let policy = SETTINGS.read().await.read().await.gpu_policy.clone(); - Ok(generate_image(model, args, &policy)?) + let device = match SETTINGS.read().await.read().await.gpu_policy { + DevicePolicy::AlwaysCpu { .. } => Device::Cpu, + DevicePolicy::AlwaysDevice { .. } => { + Device::Cuda(CudaDevice::new(0).map_err(|e| CandleError::Candle(e))?) + } + _ => { + warn!("Unknown device policy, executing on CPU"); + Device::Cpu + } + }; + + Ok(sd_generate_image(model, args, device)?) 
} } diff --git a/crates/edgen_rt_llama_cpp/Cargo.toml b/crates/edgen_rt_llama_cpp/Cargo.toml index 5b0d33b..eba4ffe 100644 --- a/crates/edgen_rt_llama_cpp/Cargo.toml +++ b/crates/edgen_rt_llama_cpp/Cargo.toml @@ -4,6 +4,7 @@ version = "0.1.0" edition = "2021" [dependencies] +async-trait = { workspace = true } blake3 = { workspace = true } dashmap = { workspace = true } derive_more = { workspace = true } diff --git a/crates/edgen_rt_llama_cpp/src/lib.rs b/crates/edgen_rt_llama_cpp/src/lib.rs index 1ca1874..7f8634e 100644 --- a/crates/edgen_rt_llama_cpp/src/lib.rs +++ b/crates/edgen_rt_llama_cpp/src/lib.rs @@ -31,13 +31,13 @@ use tokio::time::{interval, MissedTickBehavior}; use tokio::{select, spawn}; use tracing::{error, info}; +use edgen_core::cleanup_interval; use edgen_core::llm::{ inactive_llm_session_ttl, inactive_llm_ttl, CompletionArgs, LLMEndpoint, LLMEndpointError, ASSISTANT_TAG, SYSTEM_TAG, TOOL_TAG, USER_TAG, }; use edgen_core::perishable::{ActiveSignal, Perishable, PerishableReadGuard, PerishableWriteGuard}; use edgen_core::settings::{DevicePolicy, SETTINGS}; -use edgen_core::{cleanup_interval, BoxedFuture}; // TODO this should be in settings const SINGLE_MESSAGE_LIMIT: usize = 4096; @@ -70,66 +70,36 @@ impl LlamaCppEndpoint { // PANIC SAFETY: Just inserted the element if it isn't already inside the map, so must be present in the map self.models.get(&key).unwrap() } +} - /// Helper `async` function that returns the full chat completions for the specified model and - /// [`CompletionArgs`]. - async fn async_chat_completions( +#[async_trait::async_trait] +impl LLMEndpoint for LlamaCppEndpoint { + async fn chat_completions( &self, - model_path: impl AsRef, + model_path: impl AsRef + Send, args: CompletionArgs, ) -> Result { let model = self.get(model_path).await; model.chat_completions(args).await } - /// Helper `async` function that returns the chat completions stream for the specified model and - /// [`CompletionArgs`]. 
- async fn async_stream_chat_completions( + async fn stream_chat_completions( &self, - model_path: impl AsRef, + model_path: impl AsRef + Send, args: CompletionArgs, ) -> Result + Unpin + Send>, LLMEndpointError> { let model = self.get(model_path).await; model.stream_chat_completions(args).await } - async fn async_embeddings( + async fn embeddings( &self, - model_path: impl AsRef, + model_path: impl AsRef + Send, inputs: Vec, ) -> Result>, LLMEndpointError> { let model = self.get(model_path).await; model.embeddings(inputs).await } -} - -impl LLMEndpoint for LlamaCppEndpoint { - fn chat_completions<'a>( - &'a self, - model_path: impl AsRef + Send + 'a, - args: CompletionArgs, - ) -> BoxedFuture> { - let pinned = Box::pin(self.async_chat_completions(model_path, args)); - Box::new(pinned) - } - - fn stream_chat_completions<'a>( - &'a self, - model_path: impl AsRef + Send + 'a, - args: CompletionArgs, - ) -> BoxedFuture + Unpin + Send>, LLMEndpointError>> { - let pinned = Box::pin(self.async_stream_chat_completions(model_path, args)); - Box::new(pinned) - } - - fn embeddings<'a>( - &'a self, - model_path: impl AsRef + Send + 'a, - inputs: Vec, - ) -> BoxedFuture>, LLMEndpointError>> { - let pinned = Box::pin(self.async_embeddings(model_path, inputs)); - Box::new(pinned) - } fn reset(&self) { self.models.clear(); @@ -238,7 +208,9 @@ impl UnloadingModel { async fn chat_completions(&self, args: CompletionArgs) -> Result { let (_model_signal, model_guard) = get_or_init_model(&self.model, &self.path).await?; - if args.one_shot { + let prompt = format!("{}<|ASSISTANT|>", args.messages); + + if args.one_shot.unwrap_or(false) { info!("Allocating one-shot LLM session"); let mut params = SessionParams::default(); let threads = SETTINGS.read().await.read().await.auto_threads(false); @@ -254,7 +226,7 @@ impl UnloadingModel { .map_err(move |e| LLMEndpointError::SessionCreationFailed(e.to_string()))?; session - .advance_context_async(args.prompt) + .advance_context_async(prompt) .await .map_err(move |e| LLMEndpointError::Advance(e.to_string()))?; @@ -265,7 +237,7 @@ impl UnloadingModel { Ok(handle.into_string_async().await) } else { - let (session, mut id, new_context) = self.take_chat_session(&args.prompt).await; + let (session, mut id, new_context) = self.take_chat_session(&prompt).await; let (_session_signal, handle) = { let (session_signal, mut session_guard) = @@ -301,7 +273,9 @@ impl UnloadingModel { ) -> Result + Unpin + Send>, LLMEndpointError> { let (model_signal, model_guard) = get_or_init_model(&self.model, &self.path).await?; - if args.one_shot { + let prompt = format!("{}<|ASSISTANT|>", args.messages); + + if args.one_shot.unwrap_or(false) { info!("Allocating one-shot LLM session"); let mut params = SessionParams::default(); let threads = SETTINGS.read().await.read().await.auto_threads(false); @@ -318,10 +292,10 @@ impl UnloadingModel { let sampler = StandardSampler::default(); Ok(Box::new( - CompletionStream::new_oneshot(session, &args.prompt, model_signal, sampler).await?, + CompletionStream::new_oneshot(session, &prompt, model_signal, sampler).await?, )) } else { - let (session, id, new_context) = self.take_chat_session(&args.prompt).await; + let (session, id, new_context) = self.take_chat_session(&prompt).await; let sampler = StandardSampler::default(); let tx = self.finished_tx.clone(); diff --git a/crates/edgen_rt_whisper_cpp/Cargo.toml b/crates/edgen_rt_whisper_cpp/Cargo.toml index 530efe0..9a10899 100644 --- a/crates/edgen_rt_whisper_cpp/Cargo.toml +++ 
b/crates/edgen_rt_whisper_cpp/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +async-trait = { workspace = true } dashmap = { workspace = true } edgen_core = { path = "../edgen_core" } futures = { workspace = true } diff --git a/crates/edgen_rt_whisper_cpp/src/lib.rs b/crates/edgen_rt_whisper_cpp/src/lib.rs index f6960aa..bb303dc 100644 --- a/crates/edgen_rt_whisper_cpp/src/lib.rs +++ b/crates/edgen_rt_whisper_cpp/src/lib.rs @@ -22,13 +22,13 @@ use tracing::info; use uuid::Uuid; use whisper_cpp::{WhisperModel, WhisperParams, WhisperSampling, WhisperSession}; +use edgen_core::cleanup_interval; use edgen_core::perishable::{ActiveSignal, Perishable, PerishableReadGuard, PerishableWriteGuard}; use edgen_core::settings::{DevicePolicy, SETTINGS}; use edgen_core::whisper::{ inactive_whisper_session_ttl, inactive_whisper_ttl, parse, TranscriptionArgs, WhisperEndpoint, WhisperEndpointError, }; -use edgen_core::{cleanup_interval, BoxedFuture}; /// A large language model endpoint, implementing [`WhisperEndpoint`] using a [`whisper_cpp`] backend. pub struct WhisperCppEndpoint { @@ -57,12 +57,13 @@ impl WhisperCppEndpoint { // PANIC SAFETY: Just inserted the element if it isn't already inside the map, so must be present in the map self.models.get(&key).unwrap() } +} - /// Helper `async` function that returns the transcription for the specified model and - /// [`TranscriptionArgs`] - async fn async_transcription( +#[async_trait::async_trait] +impl WhisperEndpoint for WhisperCppEndpoint { + async fn transcription( &self, - model_path: impl AsRef, + model_path: impl AsRef + Send, args: TranscriptionArgs, ) -> Result<(String, Option), WhisperEndpointError> { let pcm = parse::pcm(&args.file)?; @@ -71,17 +72,6 @@ impl WhisperCppEndpoint { .transcription(args.create_session, args.session, pcm) .await } -} - -impl WhisperEndpoint for WhisperCppEndpoint { - fn transcription<'a>( - &'a self, - model_path: impl AsRef + Send + 'a, - args: TranscriptionArgs, - ) -> BoxedFuture), WhisperEndpointError>> { - let pinned = Box::pin(self.async_transcription(model_path, args)); - Box::new(pinned) - } fn reset(&self) { self.models.clear(); diff --git a/crates/edgen_server/Cargo.toml b/crates/edgen_server/Cargo.toml index c9a9f8c..17ea3d5 100644 --- a/crates/edgen_server/Cargo.toml +++ b/crates/edgen_server/Cargo.toml @@ -23,7 +23,9 @@ hyper = { workspace = true } hyper-util = { workspace = true } once_cell = { workspace = true } pin-project = { workspace = true } -reqwest = { workspace = true, features = ["blocking", "multipart"] } +rand = "0.8.5" +reqwest = { workspace = true, features = ["blocking", "multipart", "json"] } +reqwest-eventsource = "0.6.0" rubato = "0.15.0" serde = { workspace = true } serde_derive = { workspace = true } @@ -54,3 +56,8 @@ llama_cuda = ["edgen_rt_llama_cpp/cuda"] llama_metal = ["edgen_rt_llama_cpp/metal"] whisper_cuda = ["edgen_rt_whisper_cpp/cuda"] candle_cuda = ["edgen_rt_image_generation_candle/cuda"] + +[[bin]] +name = "chatter" +test = false +bench = false diff --git a/crates/edgen_server/src/bin/chatter.rs b/crates/edgen_server/src/bin/chatter.rs new file mode 100644 index 0000000..48248af --- /dev/null +++ b/crates/edgen_server/src/bin/chatter.rs @@ -0,0 +1,363 @@ +use std::borrow::Cow; +use std::path::PathBuf; +use std::time::Duration; + +use either::Either; +use futures::StreamExt; +use rand::Rng; +use reqwest_eventsource::{retry, Event}; +use reqwest_eventsource::{Error, 
EventSource}; +use tokio::io::AsyncWriteExt; +use tokio::sync::mpsc; +use tokio::task::JoinSet; +use tokio::time::{sleep, Instant}; +use tracing::{debug, error, info}; +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; + +use edgen_server::openai_shim::{ChatCompletionChunk, ChatMessage, CreateChatCompletionRequest}; + +const START_PROMPTS: [&str; 6] = [ + "Hello!", + "Please give me a number between 1 and 50.", + "Please tell me a short story.", + "Please tell me a long story.", + "What is the capital of Portugal?", + "What is the current weather like in France?", +]; + +const CONTINUE_PROMPTS: [&str; 4] = [ + "Please continue.", + "Tell me more.", + "Can you give me more details?", + "I don't understand.", +]; + +const LARGE_CONTEXT: &str = r#"Gordon Freeman, a recently employed theoretical physicist, is involved in an experiment analyzing an unknown crystalline artifact; however, when the anti-mass spectrometer beam contacts the crystal, it creates a resonance cascade that opens a dimensional rift between Black Mesa and another world called Xen, causing monsters to swarm Black Mesa and kill many of the facility's personnel. Attempts by the Black Mesa personnel to close the rift are unsuccessful, leading to a Marine Recon unit being sent in to silence the facility, including any survivors from the science team. Freeman fights through the facility to meet with several other scientists, who decide to travel to the alien dimension to stop the aliens. On Xen, Freeman eliminates the alien "leader" and is confronted by the G-Man, who offers Freeman employment before putting him into stasis.[2] Back in Black Mesa, a second alien race begins an invasion, but is stopped when a Marine corporal, Adrian Shephard, collapses its portal in the facility. The G-Man then destroys Black Mesa with a nuclear warhead, and detains Shephard in stasis. Barney Calhoun, a security officer, also escaped from the facility with Dr. Rosenberg and two other scientists. Nearly twenty years later,[2] Half-Life 2 opens as the G-Man brings Freeman out of stasis and inserts him into a dystopian Earth ruled by the Combine, a faction consisting of human and alien members, that used the dimensional rift caused at Black Mesa to conquer Earth in the interim. In the Eastern European settlement City 17, Freeman meets surviving members of the Black Mesa incident, including Isaac Kleiner, Barney Calhoun, Eli Vance and his daughter Alyx Vance, and aids in the human resistance against Combine rule. The Xen aliens, the Vortigaunts, who have been enslaved by the Combine, also assist the resistance. When his presence is made known to former Black Mesa administrator and Combine spokesman Wallace Breen, Freeman becomes a prime target for the Combine forces. Eventually, Freeman sparks a full revolution amongst the human citizens after destroying Nova Prospekt, a major Combine base and troop-production facility. Eli Vance and his daughter are subsequently captured by the Combine, and Freeman helps the resistance forces attack the Combine's Citadel to rescue them, fighting alongside Barney. Freeman fights his way through the Citadel, making his way to Breen's office. He is temporarily captured, but freed by Dr. Mossman, along with Eli and Alyx. Breen attempts to flee in a teleporter, but is presumed dead after Freeman destroys the dark energy reactor at the Citadel's top. 
The story continues with Half-Life 2: Episode One, as the G-Man then arrives to extract Freeman before he is engulfed in the explosion, but is interrupted when Vortigaunts liberate Freeman from stasis and place both him and Alyx Vance at the bottom of the Citadel. Alyx then contacts her father, Eli Vance, and Isaac Kleiner, who have escaped the city into the surrounding countryside. Kleiner informs them that the reactor's core has gone critical due to the destruction of the dark energy reaction, and is at risk of exploding at any moment, an explosion which could completely destroy City 17. To delay the explosion they must enter the Citadel's now-decaying core and attempt to stabilize its primary reactor while the citizens evacuate the city from a train station. While inside, they discover that the Combine are attempting to speed up the destruction of the reactor, and use the destruction of the Citadel to call for reinforcements from the Combine's native dimension. After downloading critical data, they move through the war-torn city to the train station to take the last train out of the city. The Combine then destroy the reactor and thus both the Citadel and the city; the resulting explosion causes the train to derail. Half-Life 2: Episode Two begins as Freeman awakens in one of the wrecked train cars with Alyx outside. In the distance a forming super-portal is visible where the Citadel used to stand. They begin a journey through the White Forest to a resistance-controlled missile base in the nearby mountains. Along the way, Freeman and Alyx are ambushed and Alyx is severely injured. However, a group of Vortigaunts are able to heal her. During the healing ritual, Freeman receives word from G-Man, indicating that the Vortigaunts were keeping him at bay. G-Man demands that Freeman take Alyx to White Forest as safely as possible, saying that he cannot help as per restrictions he has agreed to. They are able to reach the resistance base and deliver the data, which contains the codes to destroy the portal as well as information on the Borealis, an enigmatic research vessel operated by Black Mesa's rival, Aperture Science; however, the ship disappeared while testing portal technology. The base then launches a satellite that is able to shut down the super-portal, cutting off the Combine from outside assistance. However, as Alyx and Freeman prepare to travel to the Arctic and investigate the Borealis, they are attacked by Combine Advisors, who kill Eli Vance, before being driven off by Alyx's pet robot, D0g."#; + +const LARGE_PROMPTS: [&str; 5] = [ + "Please resume the Half-Life story.", + "Please give a summary of the Half-Life story.", + "Do you think Gordon's actions were correct?", + "What was Alyx's pet robot called?", + "Please write a story similar to Half-Life.", +]; + +/// Send an arbitrary number of requests to the streaming chat endpoint. +#[derive(argh::FromArgs, PartialEq, Debug, Clone)] +pub struct Chat { + /// the total amount of requests sent. + #[argh(positional, default = "10")] + pub requests: usize, + + /// the base chance that a conversation will continue. + #[argh(option, short = 'b', default = "0.6")] + pub continue_chance: f32, + + /// how much the chance to continue a conversation will decrease with each successive message. + #[argh(option, short = 'd', default = "0.05")] + pub chance_decay: f32, + + /// the minimum amount of time to wait before a request is sent. 
+ #[argh(option, short = 'i', default = "3.0")] + pub min_idle: f32, + + /// the maximum amount of time to wait before a request is sent. + #[argh(option, short = 'a', default = "10.0")] + pub max_idle: f32, + + /// the maximum size of a received message. + #[argh(option, short = 'l', default = "1000")] + pub message_limit: usize, + + /// the chance that a request will start with large context. + #[argh(option, short = 'e', default = "0.0")] + pub large_chance: f32, + + /// the base URL of the endpoint the requests will be sent to. + #[argh( + option, + short = 'u', + default = "String::from(\"http://127.0.0.1:33322\")" + )] + pub url: String, +} + +#[tokio::main] +async fn main() { + let format = tracing_subscriber::fmt::layer().compact(); + let filter = tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or( + tracing_subscriber::EnvFilter::default() + .add_directive(tracing_subscriber::filter::LevelFilter::INFO.into()), + ); + tracing_subscriber::registry() + .with(format) + .with(filter) + .init(); + + let chat_args: Chat = argh::from_env(); + + assert!( + chat_args.min_idle <= chat_args.max_idle, + "Minimum idle time cannot be higher than the maximum" + ); + + let mut rng = rand::thread_rng(); + + let mut request_chains = vec![]; + let mut chain: usize = 0; + for _ in 0..chat_args.requests { + let chance = f32::max( + chat_args.continue_chance - chat_args.chance_decay * chain as f32, + 0.0, + ); + + chain += 1; + if chance < rng.gen() { + request_chains.push(chain); + chain = 0; + } + } + + if chain > 0 { + request_chains.push(chain); + } + + let mut join_set = JoinSet::new(); + let (tx, mut rx) = mpsc::unbounded_channel(); + for (id, count) in request_chains.drain(..).enumerate() { + join_set.spawn(chain_requests(chat_args.clone(), count, id, tx.clone())); + } + drop(tx); + + let mut first_tokens = vec![]; + let mut all_tokens = vec![]; + let mut all_tokens_nf = vec![]; + let mut token_counts = vec![]; + + let date_time = time::OffsetDateTime::now_utc().to_string(); + let fmt_time = date_time[..date_time.len() - 18] + .replace(' ', "_") + .replace(':', "-"); + let file_name = format!( + "n{}_b{:.3}_d{:.3}_i{:.3}_a{:.3}_l{}_e{:.3}_{}", + chat_args.requests, + chat_args.continue_chance, + chat_args.chance_decay, + chat_args.min_idle, + chat_args.max_idle, + chat_args.message_limit, + chat_args.large_chance, + fmt_time + ); + let file_name = file_name.replace('.', "-"); + let file_path = format!("out/{file_name}.csv"); + if !PathBuf::from("out").exists() { + tokio::fs::create_dir("out") + .await + .expect("Failed to create output directory"); + } + let mut f = tokio::fs::File::create(&file_path).await.unwrap(); + while let Some(stats) = rx.recv().await { + f.write_all(format!("{}\n", stats.to_row(",")).as_bytes()) + .await + .expect("Failed to write to file"); + first_tokens.push(stats.first_token); + all_tokens.extend(&stats.all_tokens); + all_tokens_nf.extend(&stats.all_tokens[1..]); + token_counts.push(stats.all_tokens.len()); + } + f.flush().await.expect("Failed to flush file"); + println!("Wrote output to file: \"{file_path}\""); + + println!("First token times:"); + print_stats(first_tokens); + println!("All token times:"); + print_stats(all_tokens); + println!("All token times (without first token):"); + print_stats(all_tokens_nf); + println!("Token counts:"); + print_token_stats(token_counts); + + while let Some(_) = join_set.join_next().await {} +} + +async fn chain_requests( + chat_args: Chat, + count: usize, + index: usize, + stats_tx: mpsc::UnboundedSender, +) { + 
let client = reqwest::Client::new(); + let base_builder = client.post(chat_args.url + "/v1/chat/completions"); + let mut body = CreateChatCompletionRequest { + messages: Default::default(), + model: Cow::from("default"), + frequency_penalty: None, + logit_bias: None, + max_tokens: Some(chat_args.message_limit as u32), + n: None, + presence_penalty: None, + seed: None, + stop: None, + stream: Some(true), + response_format: None, + temperature: None, + top_p: None, + tools: None, + tool_choice: None, + user: None, + one_shot: None, + context_hint: None, + }; + + body.messages.push(ChatMessage::System { + content: Some(Cow::from("You are Edgen, a helpful assistant.")), + name: None, + }); + + if chat_args.large_chance < rand::thread_rng().gen() { + let prompt_idx = rand::thread_rng().gen_range(0..START_PROMPTS.len()); + body.messages.push(ChatMessage::User { + content: Either::Left(Cow::from(START_PROMPTS[prompt_idx])), + name: None, + }); + } else { + body.messages.push(ChatMessage::System { + content: Some(Cow::from(LARGE_CONTEXT)), + name: None, + }); + + let prompt_idx = rand::thread_rng().gen_range(0..LARGE_PROMPTS.len()); + body.messages.push(ChatMessage::User { + content: Either::Left(Cow::from(LARGE_PROMPTS[prompt_idx])), + name: None, + }); + } + + for request in 0..count { + let wait = if chat_args.min_idle != chat_args.max_idle { + rand::thread_rng().gen_range(chat_args.min_idle..chat_args.max_idle) + } else { + chat_args.min_idle + }; + sleep(Duration::from_secs_f32(wait)).await; + info!( + "Chain {} sending request {} of {}.", + index, + request + 1, + count + ); + + let builder = base_builder.try_clone().unwrap().json(&body); + + let mut stats = RequestStatistics { + first_token: -1.0, + all_tokens: vec![], + }; + let mut t = Instant::now(); + + let mut event_source = EventSource::new(builder).unwrap(); + event_source.set_retry_policy(Box::new(retry::Never)); + let mut token_count = 0; + let mut text = "".to_string(); + while let Some(event) = event_source.next().await { + match event { + Ok(Event::Open) => {} + Ok(Event::Message(message)) => { + if token_count >= chat_args.message_limit { + event_source.close(); + break; + } + + let nt = Instant::now(); + let d = (nt - t).as_secs_f32(); + t = nt; + + if stats.first_token == -1.0 { + stats.first_token = d; + } + stats.all_tokens.push(d); + + token_count += 1; + debug!("Chain {index} has received token {token_count}"); + let response: ChatCompletionChunk = + serde_json::from_str(message.data.as_str()).unwrap(); + text += response.choices[0].delta.content.as_ref().unwrap(); + } + Err(reqwest_eventsource::Error::StreamEnded) => {} + Err(err) => { + match err { + // Error::Utf8(_) => {} + // Error::Parser(_) => {} + // Error::Transport(_) => {} + // Error::InvalidContentType(_, _) => {} + Error::InvalidStatusCode(code, response) => { + error!("Error {}: {}", code, response.text().await.unwrap()); + } + // Error::InvalidLastEventId(_) => {} + Error::StreamEnded => {} + _ => println!("Error: {}", err), + } + event_source.close(); + } + } + } + + if stats.all_tokens.len() != 0 { + stats_tx.send(stats).unwrap(); + } + + body.messages.push(ChatMessage::Assistant { + content: Some(Cow::from(text)), + name: None, + tool_calls: None, + }); + + let continue_idx = rand::thread_rng().gen_range(0..CONTINUE_PROMPTS.len()); + body.messages.push(ChatMessage::User { + content: Either::Left(Cow::from(CONTINUE_PROMPTS[continue_idx])), + name: None, + }); + } + + info!("Chain {index} finished") +} + +struct RequestStatistics { + first_token: f32, + 
all_tokens: Vec, +} + +impl RequestStatistics { + fn to_row(&self, delimiter: &str) -> String { + let mut res = self.first_token.to_string(); + for token in &self.all_tokens[1..] { + res += delimiter; + res += &token.to_string(); + } + res + } +} + +fn print_stats(mut values: Vec) { + let mean = values.iter().map(|v| *v).reduce(|a, b| a + b).unwrap() / values.len() as f32; + values.sort_unstable_by(|a, b| a.total_cmp(b)); + let min = values[0]; + let max = *values.last().unwrap(); + let median = values[values.len() / 2]; + + println!("Mean: {mean}s ; Median: {median}s ; Min: {min}s ; Max: {max}s"); +} + +fn print_token_stats(mut values: Vec) { + let mean = values.iter().map(|v| *v).reduce(|a, b| a + b).unwrap() / values.len(); + values.sort_unstable_by(|a, b| a.cmp(b)); + let min = values[0]; + let max = *values.last().unwrap(); + let median = values[values.len() / 2]; + + println!( + "Mean: {mean} tokens ; Median: {median} tokens ; Min: {min} tokens ; Max: {max} tokens" + ); +} diff --git a/crates/edgen_server/src/image_generation.rs b/crates/edgen_server/src/image_generation.rs index 8ada353..4590093 100644 --- a/crates/edgen_server/src/image_generation.rs +++ b/crates/edgen_server/src/image_generation.rs @@ -1,17 +1,30 @@ -use crate::audio::ChatCompletionError; -use crate::model_descriptor::{ModelDescriptor, ModelDescriptorError, ModelPaths, Quantization}; +use crate::model_descriptor::{ + ModelDescriptor, ModelDescriptorError, ModelPaths, Quantization, StableDiffusionFiles, +}; use axum::http::StatusCode; use axum::response::{IntoResponse, Response}; use axum::Json; +use dashmap::DashMap; use edgen_core::image_generation::{ ImageGenerationArgs, ImageGenerationEndpoint, ImageGenerationEndpointError, ModelFiles, }; use edgen_rt_image_generation_candle::CandleImageGenerationEndpoint; +use either::Either; use serde_derive::{Deserialize, Serialize}; use std::borrow::Cow; use thiserror::Error; use utoipa::ToSchema; +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct Model<'a> { + unet_weights: Cow<'a, str>, + vae_weights: Cow<'a, str>, + clip_weights: Cow<'a, str>, + /// Beware that not all models support clip2. + clip2_weights: Option>, + tokenizer: Cow<'a, str>, +} + /// A request to generate images for the provided context. /// This request is not at all conformant with OpenAI's API, as that one is very bare-bones, lacking /// in many parameters that we need. @@ -25,7 +38,7 @@ pub struct CreateImageGenerationRequest<'a> { pub prompt: Cow<'a, str>, /// The model to use for generating completions. - pub model: Cow<'a, str>, + pub model: Either, Model<'a>>, /// The width of the generated image. pub width: Option, @@ -44,7 +57,7 @@ pub struct CreateImageGenerationRequest<'a> { /// Default: 1 pub images: Option, - /// The random number generator seed to used for the generation. + /// The random number generator seed to use for the generation. /// /// By default, a random seed is used. pub seed: Option, @@ -57,7 +70,9 @@ pub struct CreateImageGenerationRequest<'a> { /// The Variational Auto-Encoder scale to use for generation. /// - /// This value should probably not be set. + /// Required if `model` is not a pre-made descriptor name. + /// + /// This value should probably not be set, if `model` is a pre-made descriptor name. pub vae_scale: Option, } @@ -75,14 +90,17 @@ pub struct ImageGenerationResponse { #[serde(tag = "error")] pub enum ImageGenerationError { /// The provided model could not be loaded. 
- #[error("failed to load model: {0}")] + #[error(transparent)] Model(#[from] ModelDescriptorError), - /// Some error has occured inside the endpoint. - #[error("endpoint error: {0}")] + /// Some error has occurred inside the endpoint. + #[error(transparent)] Endpoint(#[from] ImageGenerationEndpointError), /// This error should be unreachable. #[error("Something went wrong")] Unreachable, + /// Some parameter was missing from the request. + #[error("A parameter was missing from the request: {0}")] + MissingParam(String), } impl IntoResponse for ImageGenerationError { @@ -98,7 +116,7 @@ impl IntoResponse for ImageGenerationError { /// cannot do. /// /// On failure, may raise a `500 Internal Server Error` with a JSON-encoded [`ImageGenerationError`] -/// to the peer.. +/// to the peer. #[utoipa::path( post, path = "/image/generations", @@ -111,13 +129,45 @@ responses( pub async fn generate_image( Json(req): Json>, ) -> Result { - let descriptor = crate::model_descriptor::get(req.model.as_ref())?; + let quantization; + let descriptor = match req.model { + Either::Left(template) => { + quantization = Quantization::F16; + crate::model_descriptor::get(template.as_ref())? + .value() + .clone() // Not ideal to clone, but otherwise the code complexity will greatly increase + } + Either::Right(custom) => { + if req.vae_scale.is_none() { + return Err(ImageGenerationError::MissingParam( + "VAE scale must be provided when manually specifying model files".to_string(), + )); + } + quantization = Quantization::Default; + let files = DashMap::new(); + files.insert( + quantization, + StableDiffusionFiles { + tokenizer: custom.tokenizer.to_string(), + clip_weights: custom.clip_weights.to_string(), + clip2_weights: custom.clip2_weights.map(|c| c.to_string()), + vae_weights: custom.vae_weights.to_string(), + unet_weights: custom.unet_weights.to_string(), + }, + ); + ModelDescriptor::StableDiffusion { + files, + steps: 30, + vae_scale: req.vae_scale.unwrap(), + } + } + }; let model_files; let default_steps; let default_vae_scale; if let ModelDescriptor::StableDiffusion { steps, vae_scale, .. - } = descriptor.value() + } = descriptor { if let ModelPaths::StableDiffusion { unet_weights, @@ -125,7 +175,7 @@ pub async fn generate_image( clip_weights, clip2_weights, tokenizer, - } = descriptor.preload_files(Quantization::F16).await? + } = descriptor.preload_files(quantization).await? 
     {
         model_files = ModelFiles {
             tokenizer,
@@ -152,11 +202,11 @@ pub async fn generate_image(
             uncond_prompt: req.uncond_prompt.unwrap_or(Cow::from("")).to_string(),
             width: req.width,
             height: req.height,
-            steps: req.steps.unwrap_or(*default_steps),
+            steps: req.steps.unwrap_or(default_steps),
             images: req.images.unwrap_or(1),
             seed: req.seed,
             guidance_scale: req.guidance_scale.unwrap_or(7.5),
-            vae_scale: req.vae_scale.unwrap_or(*default_vae_scale),
+            vae_scale: req.vae_scale.unwrap_or(default_vae_scale),
         },
     )
     .await?;
diff --git a/crates/edgen_server/src/model.rs b/crates/edgen_server/src/model.rs
index 3e66dc3..1cf31b8 100644
--- a/crates/edgen_server/src/model.rs
+++ b/crates/edgen_server/src/model.rs
@@ -49,7 +49,7 @@ pub enum ModelKind {
     LLM,
     Whisper,
     ChatFaker,
-    ImageDiffusion,
+    StableDiffusion,
 }
 
 #[derive(Debug, PartialEq)]
diff --git a/crates/edgen_server/src/model_descriptor.rs b/crates/edgen_server/src/model_descriptor.rs
index 7ed3116..a61dbdf 100644
--- a/crates/edgen_server/src/model_descriptor.rs
+++ b/crates/edgen_server/src/model_descriptor.rs
@@ -1,4 +1,5 @@
 use crate::model::{Model, ModelError, ModelKind};
+use crate::openai_shim::{parse_model_param, ParseError};
 use crate::types::Endpoint;
 use dashmap::DashMap;
 use edgen_core::settings;
@@ -17,10 +18,13 @@ pub enum ModelDescriptorError {
     Preload(#[from] ModelError),
     #[error("The specified model was not found")]
     NotFound,
+    #[error(transparent)]
+    Parse(#[from] ParseError),
 }
 
 /// The descriptor of an artificial intelligence model, containing every bit of data required to
 /// execute the model.
+#[derive(Clone)]
 pub enum ModelDescriptor {
     /// A stable diffusion model.
     StableDiffusion {
@@ -35,17 +39,13 @@ pub enum ModelDescriptor {
     },
 }
 
+#[derive(Clone)]
 pub struct StableDiffusionFiles {
-    unet_weights_repo: String,
-    unet_weights_file: String,
-    vae_weights_repo: String,
-    vae_weights_file: String,
-    clip_weights_repo: String,
-    clip_weights_file: String,
-    clip2_weights_repo: Option<String>,
-    clip2_weights_file: Option<String>,
-    tokenizer_repo: String,
-    tokenizer_file: String,
+    pub unet_weights: String,
+    pub vae_weights: String,
+    pub clip_weights: String,
+    pub clip2_weights: Option<String>,
+    pub tokenizer: String,
 }
 
 #[derive(Copy, Clone, Hash, Eq, PartialEq)]
@@ -65,6 +65,30 @@ pub enum ModelPaths {
 }
 
 impl ModelDescriptor {
+    async fn get_file(&self, file_link: &str) -> Result<PathBuf, ModelDescriptorError> {
+        let (dir, kind) = match self {
+            ModelDescriptor::StableDiffusion { .. } => (
+                PathBuf::from(settings::image_generation_dir().await),
+                ModelKind::StableDiffusion,
+            ),
+        };
+
+        let path = PathBuf::from(file_link);
+        if path.is_file() {
+            return Ok(path);
+        }
+
+        let path = dir.join(file_link);
+        if path.is_file() {
+            return Ok(path);
+        }
+
+        let (owner, repo, name) = parse_model_param(file_link)?;
+        let mut file = Model::new(kind, &name, &format!("{owner}/{repo}"), &dir);
+        file.preload(Endpoint::ImageGeneration).await?;
+        Ok(file.file_path()?)
+    }
+
     pub async fn preload_files(
         &self,
         quantization: Quantization,
@@ -75,60 +99,17 @@ impl ModelDescriptor {
             if files.is_none() {
                 return Err(ModelDescriptorError::QuantizationUnavailable);
            }
+            let files = files.unwrap();
 
-            let dir = PathBuf::from(settings::image_generation_dir().await);
-            let unet = {
-                let mut unet = Model::new(
-                    ModelKind::ImageDiffusion,
-                    &files.unet_weights_file,
-                    &files.unet_weights_repo,
-                    &dir,
-                );
-                unet.preload(Endpoint::ImageGeneration).await?;
-                unet.file_path()?
- }; - let vae = { - let mut vae = Model::new( - ModelKind::ImageDiffusion, - &files.vae_weights_file, - &files.vae_weights_repo, - &dir, - ); - vae.preload(Endpoint::ImageGeneration).await?; - vae.file_path()? - }; - let clip = { - let mut clip = Model::new( - ModelKind::ImageDiffusion, - &files.clip_weights_file, - &files.clip_weights_repo, - &dir, - ); - clip.preload(Endpoint::ImageGeneration).await?; - clip.file_path()? - }; - let clip2 = if files.clip2_weights_file.is_some() { - let mut clip2 = Model::new( - ModelKind::ImageDiffusion, - files.clip2_weights_file.as_ref().unwrap(), - files.clip2_weights_repo.as_ref().unwrap(), - &dir, - ); - clip2.preload(Endpoint::ImageGeneration).await?; - Some(clip2.file_path()?) + let unet = self.get_file(&files.unet_weights).await?; + let vae = self.get_file(&files.vae_weights).await?; + let clip = self.get_file(&files.clip_weights).await?; + let clip2 = if let Some(clip2) = &files.clip2_weights { + Some(self.get_file(&clip2).await?) } else { None }; - let tokenizer = { - let mut tokenizer = Model::new( - ModelKind::ImageDiffusion, - &files.tokenizer_file, - &files.tokenizer_repo, - &dir, - ); - tokenizer.preload(Endpoint::ImageGeneration).await?; - tokenizer.file_path()? - }; + let tokenizer = self.get_file(&files.tokenizer).await?; ModelPaths::StableDiffusion { unet_weights: unet, @@ -149,31 +130,30 @@ pub fn init() { model_files.insert( Quantization::Default, StableDiffusionFiles { - tokenizer_repo: "openai/clip-vit-base-patch32".to_string(), - tokenizer_file: "tokenizer.json".to_string(), - clip_weights_repo: "stabilityai/stable-diffusion-2-1".to_string(), - clip_weights_file: "text_encoder/model.safetensors".to_string(), - clip2_weights_repo: None, - clip2_weights_file: None, - vae_weights_repo: "stabilityai/stable-diffusion-2-1".to_string(), - vae_weights_file: "vae/diffusion_pytorch_model.safetensors".to_string(), - unet_weights_repo: "stabilityai/stable-diffusion-2-1".to_string(), - unet_weights_file: "unet/diffusion_pytorch_model.safetensors".to_string(), + tokenizer: "openai/clip-vit-base-patch32/tokenizer.json".to_string(), + clip_weights: "stabilityai/stable-diffusion-2-1/text_encoder/model.safetensors" + .to_string(), + clip2_weights: None, + vae_weights: "stabilityai/stable-diffusion-2-1/vae/diffusion_pytorch_model.safetensors" + .to_string(), + unet_weights: + "stabilityai/stable-diffusion-2-1/unet/diffusion_pytorch_model.safetensors" + .to_string(), }, ); model_files.insert( Quantization::F16, StableDiffusionFiles { - tokenizer_repo: "openai/clip-vit-base-patch32".to_string(), - tokenizer_file: "tokenizer.json".to_string(), - clip_weights_repo: "stabilityai/stable-diffusion-2-1".to_string(), - clip_weights_file: "text_encoder/model.fp16.safetensors".to_string(), - clip2_weights_repo: None, - clip2_weights_file: None, - vae_weights_repo: "stabilityai/stable-diffusion-2-1".to_string(), - vae_weights_file: "vae/diffusion_pytorch_model.fp16.safetensors".to_string(), - unet_weights_repo: "stabilityai/stable-diffusion-2-1".to_string(), - unet_weights_file: "unet/diffusion_pytorch_model.fp16.safetensors".to_string(), + tokenizer: "openai/clip-vit-base-patch32/tokenizer.json".to_string(), + clip_weights: "stabilityai/stable-diffusion-2-1/text_encoder/model.fp16.safetensors" + .to_string(), + clip2_weights: None, + vae_weights: + "stabilityai/stable-diffusion-2-1/vae/diffusion_pytorch_model.fp16.safetensors" + .to_string(), + unet_weights: + "stabilityai/stable-diffusion-2-1/unet/diffusion_pytorch_model.fp16.safetensors" + .to_string(), }, 
     );
 
     let model = ModelDescriptor::StableDiffusion {
diff --git a/crates/edgen_server/src/openai_shim.rs b/crates/edgen_server/src/openai_shim.rs
index b0a8e02..8b15a14 100644
--- a/crates/edgen_server/src/openai_shim.rs
+++ b/crates/edgen_server/src/openai_shim.rs
@@ -17,7 +17,6 @@
 use std::borrow::Cow;
 use std::collections::HashMap;
-use std::fmt::{Display, Formatter};
 use std::path::PathBuf;
 
 use axum::http::StatusCode;
@@ -72,21 +71,6 @@ pub enum ContentPart<'a> {
     },
 }
 
-impl<'a> Display for ContentPart<'a> {
-    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
-        match self {
-            ContentPart::Text { text } => write!(f, "{}", text),
-            ContentPart::ImageUrl { url, detail } => {
-                if let Some(detail) = detail {
-                    write!(f, "<IMAGE {}> ({})", url, detail)
-                } else {
-                    write!(f, "<IMAGE {}>", url)
-                }
-            }
-        }
-    }
-}
-
 /// A description of a function provided to a large language model, to assist it in interacting
 /// with the outside world.
 ///
@@ -233,52 +217,6 @@ pub struct ChatMessages<'a>(
     Vec<ChatMessage<'a>>,
 );
 
-impl<'a> Display for ChatMessages<'a> {
-    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
-        for message in &self.0 {
-            match message {
-                ChatMessage::System {
-                    content: Some(data),
-                    ..
-                } => {
-                    write!(f, "<|SYSTEM|>{data}")?;
-                }
-                ChatMessage::User {
-                    content: Either::Left(data),
-                    ..
-                } => {
-                    write!(f, "<|USER|>{data}")?;
-                }
-                ChatMessage::User {
-                    content: Either::Right(data),
-                    ..
-                } => {
-                    write!(f, "<|USER|>")?;
-
-                    for part in data {
-                        write!(f, "{part}")?;
-                    }
-                }
-                ChatMessage::Assistant {
-                    content: Some(data),
-                    ..
-                } => {
-                    write!(f, "<|ASSISTANT|>{data}")?;
-                }
-                ChatMessage::Tool {
-                    content: Some(data),
-                    ..
-                } => {
-                    write!(f, "<|TOOL|>{data}")?;
-                }
-                _ => {}
-            }
-        }
-
-        Ok(())
-    }
-}
-
 /// A request to generate chat completions for the provided context.
 ///
 /// An `axum` handler, [`chat_completions`][chat_completions], is provided to handle this request.
@@ -647,7 +585,7 @@ fn get_model_params(model_name: &str, dir: &str) -> Result
-fn parse_model_param(model: &str) -> Result<(String, String, String), ParseError> {
+pub(crate) fn parse_model_param(model: &str) -> Result<(String, String, String), ParseError> {
     let vs = model.split("/").collect::<Vec<&str>>();
     let l = vs.len();
     if l < 3 {
@@ -674,18 +612,125 @@ fn parse_model_param(model: &str) -> Result<(String, String, String), ParseError
     Ok((owner, repo, name))
 }
 
+impl From<ContentPart<'_>> for edgen_core::llm::ContentPart {
+    fn from(value: ContentPart) -> Self {
+        match value {
+            ContentPart::Text { text } => Self::Text {
+                text: text.to_string(),
+            },
+            ContentPart::ImageUrl { url, detail } => Self::ImageUrl {
+                url: url.to_string(),
+                detail: detail.map(|x| x.to_string()),
+            },
+        }
+    }
+}
+
+impl From<AssistantToolCall<'_>> for edgen_core::llm::AssistantToolCall {
+    fn from(value: AssistantToolCall) -> Self {
+        Self {
+            id: value.id.to_string(),
+            type_: value.type_.to_string(),
+            function: edgen_core::llm::AssistantFunctionStub {
+                name: value.function.name.to_string(),
+                arguments: value.function.arguments.to_string(),
+            },
+        }
+    }
+}
+
+impl From<ChatMessage<'_>> for edgen_core::llm::ChatMessage {
+    fn from(value: ChatMessage) -> Self {
+        match value {
+            ChatMessage::System { content, name } => Self::System {
+                content: content.map(|x| x.to_string()),
+                name: name.map(|x| x.to_string()),
+            },
+            ChatMessage::User { content, name } => Self::User {
+                content: match content {
+                    Either::Left(text) => Either::Left(text.to_string()),
+                    Either::Right(mut msgs) => Either::Right(
+                        msgs.drain(..)
+                            .map(|x| edgen_core::llm::ContentPart::from(x))
+                            .collect(),
+                    ),
+                },
+                name: name.map(|x| x.to_string()),
+            },
+            ChatMessage::Assistant {
+                content,
+                name,
+                tool_calls,
+            } => Self::Assistant {
+                content: content.map(|x| x.to_string()),
+                name: name.map(|x| x.to_string()),
+                tool_calls: tool_calls.map(|mut o| {
+                    o.drain(..)
+                        .map(|x| edgen_core::llm::AssistantToolCall::from(x))
+                        .collect()
+                }),
+            },
+            ChatMessage::Tool {
+                content,
+                tool_call_id,
+            } => Self::Tool {
+                content: content.map(|x| x.to_string()),
+                tool_call_id: tool_call_id.to_string(),
+            },
+        }
+    }
+}
+
+impl From<ChatMessages<'_>> for edgen_core::llm::ChatMessages {
+    fn from(mut value: ChatMessages) -> Self {
+        Self(
+            value
+                .drain(..)
+                .map(|x| edgen_core::llm::ChatMessage::from(x))
+                .collect(),
+        )
+    }
+}
+
+impl From<CreateChatCompletionRequest<'_>> for CompletionArgs {
+    fn from(value: CreateChatCompletionRequest) -> Self {
+        Self {
+            messages: value.messages.into(),
+            frequency_penalty: value.frequency_penalty,
+            logit_bias: value.logit_bias,
+            max_tokens: value.max_tokens,
+            n: value.n,
+            presence_penalty: value.presence_penalty,
+            seed: value.seed,
+            stop: value.stop.map(|x| match x {
+                Either::Left(text) => Either::Left(text.to_string()),
+                Either::Right(mut v) => Either::Right(v.drain(..).map(|x| x.to_string()).collect()),
+            }),
+            temperature: value.temperature,
+            top_p: value.top_p,
+            one_shot: value.one_shot,
+            context_hint: value.context_hint,
+        }
+    }
+}
+
 /// Error Parsing the model parameter
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, Error, Serialize)]
 pub enum ParseError {
     /// Expected are three fields separated by '/'; fewer fields were provided.
+    #[error("Expected are three fields separated by '/'; fewer fields were provided")]
     MissingSeparator,
     /// Expected are three fields separated by '/'; more than three fields were provided.
+    #[error("Expected are three fields separated by '/'; more than three fields were provided")]
     TooManySeparators,
     /// No model name was provided.
+    #[error("No model name was provided")]
     NoModel,
     /// No repo owner was provided.
+    #[error("No repo owner was provided")]
     NoOwner,
     /// No repo was provided.
+    #[error("No repo was provided")]
     NoRepo,
 }
 
@@ -766,31 +811,16 @@ pub async fn chat_completions(
         model_name: params.name.to_string(),
     })?;
 
-    let untokenized_context = format!("{}<|ASSISTANT|>", req.messages);
-
-    let mut args = CompletionArgs {
-        prompt: untokenized_context,
-        seed: req.seed,
-        context_hint: req.context_hint,
-        ..Default::default()
-    };
-
-    if let Some(one_shot) = req.one_shot {
-        args.one_shot = one_shot;
-    }
-
-    if let Some(frequency_penalty) = req.frequency_penalty {
-        args.frequency_penalty = frequency_penalty;
-    }
-
     let stream_response = req.stream.unwrap_or(false);
     let fp = format!("edgen-{}", cargo_crate_version!());
 
     let response = if stream_response {
         let completions_stream = {
             let result = match model.kind {
-                ModelKind::LLM => llm::chat_completion_stream(model, args).await?,
-                ModelKind::ChatFaker => chat_faker::chat_completion_stream(model, args).await?,
+                ModelKind::LLM => llm::chat_completion_stream(model, req.into()).await?,
+                ModelKind::ChatFaker => {
+                    chat_faker::chat_completion_stream(model, req.into()).await?
+ } _ => panic!("we should never get here"), }; result.map(move |chunk| { @@ -814,8 +844,8 @@ pub async fn chat_completions( ChatCompletionResponse::Stream(Sse::new(completions_stream)) } else { let content_str = match model.kind { - ModelKind::LLM => llm::chat_completion(model, args).await?, - ModelKind::ChatFaker => crate::chat_faker::chat_completion(model, args).await?, + ModelKind::LLM => llm::chat_completion(model, req.into()).await?, + ModelKind::ChatFaker => crate::chat_faker::chat_completion(model, req.into()).await?, _ => panic!("we should never get here"), }; let response = ChatCompletion { diff --git a/docs/src/app/api-reference/audio/page.mdx b/docs/src/app/api-reference/audio/page.mdx index a6d37a1..2edbd37 100644 --- a/docs/src/app/api-reference/audio/page.mdx +++ b/docs/src/app/api-reference/audio/page.mdx @@ -43,6 +43,31 @@ Discover how to convert audio to text or text to audio. OpenAI compliant. {{ cla ### Optional attributes + + + The language of the input audio. Supplying the input language in ISO-639-1 format will improve accuracy and latency. + + + + + + An optional text to guide the model's style or continue a previous audio segment. The prompt should match the audio language. + + + + + + The format of the transcript output, in one of these options: json, text, srt, verbose_json, or vtt. + + + + + + The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. + If set to 0, the model will use log probability to automatically increase the temperature until certain thresholds are hit. + + + If present and true, a new audio session will be created and used for the transcription and the session's UUID is returned in the response object. A session will keep track of past inferences, this may be useful for things like live transcriptions where continuous audio is submitted across several requests. diff --git a/docs/src/app/api-reference/chat/page.mdx b/docs/src/app/api-reference/chat/page.mdx index 80dcc13..adf44fe 100644 --- a/docs/src/app/api-reference/chat/page.mdx +++ b/docs/src/app/api-reference/chat/page.mdx @@ -39,9 +39,120 @@ Generate text from text. {{ className: 'lead' }} + + ### Optional attributes + + + + A number in `[-2.0, 2.0]`. A higher number decreases the likelihood that the model repeats itself. + + + + + + A map of token IDs to `[-100.0, +100.0]`. Adds a percentage bias to those tokens before sampling; a value of `-100.0` prevents the token from being selected at all. + You could use this to, for example, prevent the model from emitting profanity. + + + + + + The maximum number of tokens to generate. If `None`, terminates at the first stop token or the end of sentence. + + + + + + How many choices to generate for each token in the output. `1` by default. You can use this to generate several sets of completions for the same prompt. + + + + + + A number in `[-2.0, 2.0]`. Positive values "increase the model's likelihood to talk about new topics." + + + + + + The random number generator seed for the session. Random by default. + + + + + + A stop phrase or set of stop phrases. + The server will pause emitting completions if it appears to be generating a stop phrase, and will terminate completions if a full stop phrase is detected. + Stop phrases are never emitted to the client. + + + + + + If true, stream the output as it is computed by the server, instead of returning the whole completion at the end. 
+ You can use this to live-stream completions to a client. + + + + + + The format of the response stream. + This is always assumed to be JSON, which is non-conformant with the OpenAI spec. + + + + + + The sampling temperature, in `[0.0, 2.0]`. Higher values make the output more random. + + + + + + Nucleus sampling. If you set this value to 10%, only the top 10% of tokens are used for sampling, preventing sampling of very low-probability tokens. + + + + + + A list of tools made available to the model. + + + + + + If present, the tool that the user has chosen to use. + OpenAI states: + - `none` prevents any tool from being used, + - `auto` allows any tool to be used, or + - you can provide a description of the tool entirely instead of a name. + + + + + + A unique identifier for the _end user_ creating this request. This is used for telemetry and user tracking, and is unused within Edgen. + + + + + + Indicate if this is an isolated request, with no associated past or future context. This may allow for optimisations in some implementations. + Default: `false` + + + + + + A hint for how big a context will be. + # Warning + An unsound hint may severely drop performance and/or inference quality, and in some cases even cause Edgen to crash. Do not set this value unless you know what you are doing. + + + diff --git a/docs/src/app/api-reference/embeddings/page.mdx b/docs/src/app/api-reference/embeddings/page.mdx index 5533853..1b250b8 100644 --- a/docs/src/app/api-reference/embeddings/page.mdx +++ b/docs/src/app/api-reference/embeddings/page.mdx @@ -40,6 +40,20 @@ Generate embeddings from text. {{ className: 'lead' }} + ### Optional attributes + + + + The format to return the embeddings in. Can be either `float` or `base64`. + + + + + + The number of dimensions the resulting output embeddings should have. Only supported in some models. + + + diff --git a/docs/src/app/api-reference/image/page.mdx b/docs/src/app/api-reference/image/page.mdx index f25ab00..5a3cc64 100644 --- a/docs/src/app/api-reference/image/page.mdx +++ b/docs/src/app/api-reference/image/page.mdx @@ -39,6 +39,61 @@ Generate images from text. {{ className: 'lead' }} + ### Optional attributes + + + + The width of the generated image. + + + + + + The height of the generated image. + + + + + + The optional unconditional prompt. + + + + + + The number of steps to be used in the diffusion process. + + + + + + The number of images to generate. + Default: 1 + + + + + + The random number generator seed to use for the generation. + By default, a random seed is used. + + + + + + The guidance scale to use for generation, that is, how much should the model follow the prompt. + Values below 1 disable guidance. (the prompt is ignored) + + + + + + The Variational Auto-Encoder scale to use for generation. + Required if `model` is not a pre-made descriptor name. + This value should probably not be set, if `model` is a pre-made descriptor name. 
+
+
+
 
diff --git a/edgen/src-tauri/Cargo.toml b/edgen/src-tauri/Cargo.toml
index 3aab74c..86eb048 100644
--- a/edgen/src-tauri/Cargo.toml
+++ b/edgen/src-tauri/Cargo.toml
@@ -6,6 +6,7 @@ authors = ["EdgenAI"]
 license = ""
 repository = ""
 edition = "2021"
+default-run = "edgen"
 
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
@@ -19,16 +20,18 @@ serde = { version = "1.0", features = ["derive"] }
 serde_json = "1.0"
 tokio = { workspace = true, features = ["full", "tracing"] }
 tracing = { workspace = true }
-opener = "0.6.1"
+opener = "0.7.0"
 edgen_server = { path = "../../crates/edgen_server" }
 edgen_core = { path = "../../crates/edgen_core" }
 
+[target.'cfg(windows)'.dependencies]
+winapi = { version = "0.3", features = ["wincon"] }
+
 [features]
 no_gui = []
 # this feature is used for production builds or when `devPath` points to the filesystem
 # DO NOT REMOVE!!
 custom-protocol = ["tauri/custom-protocol"]
-enable-windows-terminal = []
 llama_vulkan = ["edgen_server/llama_vulkan"]
 llama_cuda = ["edgen_server/llama_cuda"]
 llama_metal = ["edgen_server/llama_metal"]
diff --git a/edgen/src-tauri/src/main.rs b/edgen/src-tauri/src/main.rs
index fb4411c..0b253d7 100644
--- a/edgen/src-tauri/src/main.rs
+++ b/edgen/src-tauri/src/main.rs
@@ -10,10 +10,7 @@
  * limitations under the License.
  */
 
-#![cfg_attr(
-    not(feature = "enable-windows-terminal"),
-    windows_subsystem = "windows"
-)]
+#![cfg_attr(windows, windows_subsystem = "windows")]
 
 #[cfg(not(feature = "no_gui"))]
 mod gui;
@@ -25,6 +22,8 @@ use once_cell::sync::Lazy;
 
 #[cfg(not(feature = "no_gui"))]
 fn main() -> EdgenResult {
+    try_attach_terminal();
+
     Lazy::force(&cli::PARSED_COMMANDS);
 
     match &cli::PARSED_COMMANDS.subcommand {
@@ -38,6 +37,8 @@ fn main() -> EdgenResult {
 
 #[cfg(feature = "no_gui")]
 fn main() -> EdgenResult {
+    try_attach_terminal();
+
     Lazy::force(&cli::PARSED_COMMANDS);
     start(&cli::PARSED_COMMANDS)
 }
@@ -58,3 +59,18 @@ fn serve(command: &'static cli::TopLevel, start_gui: bool) -> EdgenResult {
 
     handle.join()?
 }
+
+/// On Windows, attempt to attach to the parent process's terminal if not already attached.
+///
+/// This is needed because this is a Windows Subsystem binary.
+fn try_attach_terminal() {
+    #[cfg(windows)]
+    {
+        use winapi::um::wincon;
+        unsafe {
+            if wincon::GetConsoleWindow().is_null() {
+                wincon::AttachConsole(wincon::ATTACH_PARENT_PROCESS);
+            }
+        }
+    }
+}
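For reference, a minimal sketch (not part of this diff) of how a client can consume the streaming `/v1/chat/completions` endpoint that the load-test chains earlier in this diff exercise, using the same `reqwest-eventsource` machinery. The `http://localhost:33322` base address, the `default` model name, and the use of `serde_json::Value` in place of the server's `ChatCompletionChunk` type are assumptions made to keep the sketch self-contained.

```rust
// Sketch only: stream one chat completion from a locally running Edgen server.
// Address, port, and model name are assumptions, not values taken from the patch.
use futures::StreamExt;
use reqwest_eventsource::{retry, Event, EventSource};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = reqwest::Client::new();
    let builder = client
        .post("http://localhost:33322/v1/chat/completions")
        .json(&serde_json::json!({
            "model": "default",
            "stream": true,
            "messages": [
                { "role": "system", "content": "You are Edgen, a helpful assistant." },
                { "role": "user", "content": "Hello!" }
            ]
        }));

    // `EventSource::new` only fails if the request body cannot be cloned.
    let mut events = EventSource::new(builder)?;
    // Like the load-test chains, do not retry automatically when the stream drops.
    events.set_retry_policy(Box::new(retry::Never));

    let mut text = String::new();
    while let Some(event) = events.next().await {
        match event {
            Ok(Event::Open) => {}
            Ok(Event::Message(message)) => {
                // Each SSE message should carry one JSON-encoded completion chunk;
                // non-JSON frames (e.g. a terminating sentinel) are simply skipped.
                if let Ok(chunk) = serde_json::from_str::<serde_json::Value>(&message.data) {
                    if let Some(delta) = chunk["choices"][0]["delta"]["content"].as_str() {
                        text.push_str(delta);
                    }
                }
            }
            Err(reqwest_eventsource::Error::StreamEnded) => break,
            Err(err) => {
                eprintln!("stream error: {err}");
                events.close();
                break;
            }
        }
    }

    println!("{text}");
    Ok(())
}
```

The request body mirrors what the chains send (`stream: true`, a system prompt, one user prompt); the accumulated `text` is the assistant reply that the chains feed back into the conversation as the next `assistant` message.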