diff --git a/config.yml b/config.yml
index 9b24df9..ecb12c1 100644
--- a/config.yml
+++ b/config.yml
@@ -1,3 +1,6 @@
 port: 8080
 cache_dir: /tmp/models/
-model: 7b-open-chat-3.5
\ No newline at end of file
+model: 7b-open-chat-3.5
+
+# keep default model in memory
+keep_in_memory: true
diff --git a/src/api/routes/generate.rs b/src/api/routes/generate.rs
index 4eb7cc8..ba319aa 100644
--- a/src/api/routes/generate.rs
+++ b/src/api/routes/generate.rs
@@ -7,7 +7,7 @@ use axum::{
 
 use crate::{
     api::model::{CompatGenerateRequest, ErrorResponse, GenerateRequest},
-    config::Config,
+    server::AppState,
 };
 
 use super::{generate_stream::generate_stream_handler, generate_text_handler};
@@ -48,12 +48,12 @@ use super::{generate_stream::generate_stream_handler, generate_text_handler};
     )
 )]
 pub async fn generate_handler(
-    config: State<Config>,
+    app_state: State<AppState>,
     Json(payload): Json<CompatGenerateRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     if payload.stream {
         Ok(generate_stream_handler(
-            config,
+            app_state,
             Json(GenerateRequest {
                 inputs: payload.inputs,
                 parameters: payload.parameters,
@@ -63,7 +63,7 @@
         .into_response())
     } else {
         Ok(generate_text_handler(
-            config,
+            app_state,
             Json(GenerateRequest {
                 inputs: payload.inputs,
                 parameters: payload.parameters,
@@ -91,9 +91,13 @@ mod tests {
     #[ignore = "Will download model from HuggingFace"]
     #[tokio::test]
     async fn test_generate_handler_stream_enabled() {
+        let state = AppState {
+            config: Config::default(),
+            text_generation: None,
+        };
         let app = Router::new()
             .route("/", post(generate_handler))
-            .with_state(Config::default());
+            .with_state(state);
 
         let response = app
             .oneshot(
@@ -120,9 +124,13 @@ mod tests {
     #[tokio::test]
     #[ignore = "Will download model from HuggingFace"]
     async fn test_generate_handler_stream_disabled() {
+        let state = AppState {
+            config: Config::default(),
+            text_generation: None,
+        };
         let app = Router::new()
             .route("/", post(generate_handler))
-            .with_state(Config::default());
+            .with_state(state);
 
         let response = app
             .oneshot(
diff --git a/src/api/routes/generate_stream.rs b/src/api/routes/generate_stream.rs
index ae47fac..dbaaf66 100644
--- a/src/api/routes/generate_stream.rs
+++ b/src/api/routes/generate_stream.rs
@@ -1,6 +1,7 @@
+use crate::api::model::GenerateRequest;
 use crate::llm::generate_parameter::GenerateParameter;
 use crate::llm::text_generation::create_text_generation;
-use crate::{api::model::GenerateRequest, config::Config};
+use crate::server::AppState;
 use axum::{
     extract::State,
     response::{sse::Event, IntoResponse, Sse},
@@ -38,7 +39,7 @@ use std::vec;
     tag = "Text Generation Inference"
 )]
 pub async fn generate_stream_handler(
-    config: State<Config>,
+    app_state: State<AppState>,
     Json(payload): Json<GenerateRequest>,
 ) -> impl IntoResponse {
     debug!("Received request: {:?}", payload);
@@ -68,7 +69,12 @@ pub async fn generate_stream_handler(
         None => vec!["<|endoftext|>".to_string(), "</s>".to_string()],
     };
 
-    let mut generator = create_text_generation(config.model, &config.cache_dir).unwrap();
+    let config = app_state.config.clone();
+
+    let mut generator = match &app_state.text_generation {
+        Some(text_generation) => text_generation.clone(),
+        None => create_text_generation(config.model, &config.cache_dir).unwrap(),
+    };
 
     let parameter = GenerateParameter {
         temperature: temperature.unwrap_or_default(),
diff --git a/src/api/routes/generate_text.rs b/src/api/routes/generate_text.rs
index 6c95b68..085ac41 100644
--- a/src/api/routes/generate_text.rs
+++ b/src/api/routes/generate_text.rs
@@ -1,7 +1,7 @@
 use crate::{
     api::model::{ErrorResponse, GenerateRequest, GenerateResponse},
-    config::Config,
     llm::{generate_parameter::GenerateParameter, text_generation::create_text_generation},
+    server::AppState,
 };
 use axum::{extract::State, http::StatusCode, response::IntoResponse, Json};
 
@@ -40,7 +40,7 @@ use axum::{extract::State, http::StatusCode, response::IntoResponse, Json};
     tag = "Text Generation Inference"
 )]
 pub async fn generate_text_handler(
-    config: State<Config>,
+    app_state: State<AppState>,
     Json(payload): Json<GenerateRequest>,
 ) -> impl IntoResponse {
     let temperature = match &payload.parameters {
@@ -64,45 +64,40 @@
         None => 50,
     };
 
-    let generator = create_text_generation(config.model, &config.cache_dir);
-    match generator {
-        Ok(mut generator) => {
-            let parameter = GenerateParameter {
-                temperature: temperature.unwrap_or_default(),
-                top_p: top_p.unwrap_or_default(),
-                max_new_tokens: sample_len,
-                seed: 42,
-                repeat_penalty,
-                repeat_last_n,
-            };
+    let config = app_state.config.clone();
 
-            let generated_text = generator.run(&payload.inputs, parameter);
-            match generated_text {
-                Ok(generated_text) => match generated_text {
-                    Some(text) => Ok(Json(GenerateResponse {
-                        generated_text: text,
-                    })),
-                    None => Err((
-                        StatusCode::INTERNAL_SERVER_ERROR,
-                        Json(ErrorResponse {
-                            error: "Incomplete generation".to_string(),
-                            error_type: None,
-                        }),
-                    )),
-                },
-                Err(_) => Err((
-                    StatusCode::FAILED_DEPENDENCY,
-                    Json(ErrorResponse {
-                        error: "Request failed during generation".to_string(),
-                        error_type: None,
-                    }),
-                )),
-            }
-        }
+    let mut generator = match &app_state.text_generation {
+        Some(text_generation) => text_generation.clone(),
+        None => create_text_generation(config.model, &config.cache_dir).unwrap(),
+    };
+
+    let parameter = GenerateParameter {
+        temperature: temperature.unwrap_or_default(),
+        top_p: top_p.unwrap_or_default(),
+        max_new_tokens: sample_len,
+        seed: 42,
+        repeat_penalty,
+        repeat_last_n,
+    };
+
+    let generated_text = generator.run(&payload.inputs, parameter);
+    match generated_text {
+        Ok(generated_text) => match generated_text {
+            Some(text) => Ok(Json(GenerateResponse {
+                generated_text: text,
+            })),
+            None => Err((
+                StatusCode::INTERNAL_SERVER_ERROR,
+                Json(ErrorResponse {
+                    error: "Incomplete generation".to_string(),
+                    error_type: None,
+                }),
+            )),
+        },
         Err(_) => Err((
-            StatusCode::TOO_MANY_REQUESTS,
+            StatusCode::FAILED_DEPENDENCY,
             Json(ErrorResponse {
-                error: "Model is overloaded".to_string(),
+                error: "Request failed during generation".to_string(),
                 error_type: None,
             }),
         )),
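Both generate handlers now share the same selection logic: reuse the TextGeneration preloaded into AppState when it is present, otherwise build one for this request. A minimal sketch of that shared pattern in isolation (select_generator is a hypothetical helper for illustration, not code in this PR; it assumes the AppState and create_text_generation signatures shown above):

    fn select_generator(app_state: &AppState) -> TextGeneration {
        let config = app_state.config.clone();
        match &app_state.text_generation {
            // Warm path: generator preloaded at startup (keep_in_memory: true).
            Some(text_generation) => text_generation.clone(),
            // Cold path: load model and tokenizer for this request only.
            // As in the PR, unwrap() panics the request task if loading fails.
            None => create_text_generation(config.model, &config.cache_dir).unwrap(),
        }
    }

The clone() on the warm path is cheap because TextGeneration now derives Clone over Arc-wrapped fields (see src/llm/text_generation.rs below).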
diff --git a/src/api/routes/info.rs b/src/api/routes/info.rs
index d3333ef..a2b40f8 100644
--- a/src/api/routes/info.rs
+++ b/src/api/routes/info.rs
@@ -1,6 +1,6 @@
 //! This module contains the endpoint for retrieving model information.
 
-use crate::{api::model::Info, config::Config};
+use crate::{api::model::Info, server::AppState};
 use axum::{extract::State, http::StatusCode, Json};
 
 /// Endpoint to get model information.
@@ -15,7 +15,8 @@ use axum::{extract::State, http::StatusCode, Json};
     ),
     tag = "Text Generation Inference"
 )]
-pub async fn get_info_handler(config: State<Config>) -> Result<Json<Info>, StatusCode> {
+pub async fn get_info_handler(app_state: State<AppState>) -> Result<Json<Info>, StatusCode> {
+    let config = &app_state.config;
     let version = env!("CARGO_PKG_VERSION");
     let model_info = Info {
         docker_label: None,
@@ -51,9 +52,13 @@ mod tests {
             port: 8080,
             cache_dir: None,
             model: Models::default(),
+            keep_in_memory: None,
         };
 
-        let state = State(test_config.clone());
+        let state = State(AppState {
+            config: test_config.clone(),
+            text_generation: None,
+        });
         let response = get_info_handler(state).await.unwrap();
         let info = response.0;
         assert_eq!(info.max_batch_total_tokens, 2048);
diff --git a/src/api/routes/model.rs b/src/api/routes/model.rs
index 10ba05f..21a62b6 100644
--- a/src/api/routes/model.rs
+++ b/src/api/routes/model.rs
@@ -7,8 +7,8 @@ use axum::{
 
 use crate::{
     api::model::{CompatGenerateRequest, ErrorResponse, GenerateRequest},
-    config::Config,
     llm::models::Models,
+    server::AppState,
 };
 
 use super::{generate_stream::generate_stream_handler, generate_text_handler};
@@ -54,15 +54,15 @@
 pub async fn generate_model_handler(
     Path(model): Path<Models>,
-    config: State<Config>,
+    app_state: State<AppState>,
     Json(payload): Json<CompatGenerateRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
-    let mut config = config.clone();
-    config.model = model;
+    let mut app_state = app_state.clone();
+    app_state.config.model = model;
 
     if payload.stream {
         Ok(generate_stream_handler(
-            config,
+            app_state,
             Json(GenerateRequest {
                 inputs: payload.inputs,
                 parameters: payload.parameters,
@@ -72,7 +72,7 @@
         .into_response())
     } else {
         Ok(generate_text_handler(
-            config,
+            app_state,
             Json(GenerateRequest {
                 inputs: payload.inputs,
                 parameters: payload.parameters,
diff --git a/src/config.rs b/src/config.rs
index a6fd07a..943d2d9 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -18,6 +18,9 @@ pub struct Config {
 
     /// Model to be used by the server.
     pub model: Models,
+
+    /// Whether to keep the default model in memory.
+    pub keep_in_memory: Option<bool>,
 }
 
 /// Loads the application configuration from a YAML file.
@@ -67,5 +70,6 @@ mod tests {
         assert_eq!(config.port, 8080);
         assert_eq!(config.cache_dir, Some(PathBuf::from("/tmp")));
         assert_eq!(config.model, Models::OpenChat35);
+        assert_eq!(config.keep_in_memory, None);
     }
 }
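Because keep_in_memory is declared as Option<bool>, existing config files that omit the key still deserialize: the missing field becomes None, which main.rs treats the same as false. A quick sketch of that behavior (assuming Config derives serde::Deserialize and Models accepts the 7b-open-chat-3.5 alias, as config.yml and load_config imply; this is not a test in the PR):

    // Hypothetical backward-compatibility check:
    let yaml = "port: 8080\ncache_dir: /tmp/models/\nmodel: 7b-open-chat-3.5";
    let config: Config = serde_yaml::from_str(yaml).unwrap();
    assert_eq!(config.keep_in_memory, None); // absent key -> None, i.e. off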
diff --git a/src/llm/text_generation.rs b/src/llm/text_generation.rs
index 738107c..592396c 100644
--- a/src/llm/text_generation.rs
+++ b/src/llm/text_generation.rs
@@ -26,6 +26,7 @@ use super::{
     Model,
 };
 
+#[derive(Clone)]
 pub struct TextGeneration {
     model: Arc<Mutex<Model>>,
     tokenizer: Arc<Mutex<Tokenizer>>,
diff --git a/src/main.rs b/src/main.rs
index c52ec0b..04dc1ba 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,7 +1,7 @@
 use chat_flame_backend::{
     config::{load_config, Config},
     llm::{
-        generate_parameter::GenerateParameter, loader::create_model, models::Models,
+        generate_parameter::GenerateParameter, models::Models,
         text_generation::create_text_generation,
     },
     server::server,
@@ -72,11 +72,15 @@ async fn generate_text(
 
 async fn start_server(model: Models, config: Config) {
     info!("Starting server");
     info!("preload model");
-    let _ = create_model(model, &config.cache_dir);
+    let text_generation = create_text_generation(model, &config.cache_dir).unwrap();
 
     info!("Running on port: {}", config.port);
     let addr = SocketAddr::from(([0, 0, 0, 0], config.port));
-    let app = server(config);
+
+    let app = match &config.keep_in_memory {
+        Some(true) => server(config, Some(text_generation)),
+        _ => server(config, None),
+    };
 
     info!("Server running at http://{}", addr);
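Two things stand out in main.rs: the model is now always preloaded at startup, so unwrap() turns a load failure into a startup panic rather than a per-request one, and the preloaded generator is only handed to the router when keep_in_memory is Some(true); otherwise it is dropped after the match (the load still warms the on-disk cache, matching the old create_model preload). The match could also be written more compactly; a sketch, not part of the PR:

    // Read the flag before `config` moves into server().
    let keep = config.keep_in_memory.unwrap_or(false);
    let app = server(config, keep.then_some(text_generation));

bool::then_some has been stable since Rust 1.62, so this is purely a style choice.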
diff --git a/src/server.rs b/src/server.rs
index e151714..3fa7775 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -16,8 +16,15 @@ use crate::{
         routes::{get_health_handler, get_info_handler},
     },
     config::Config,
+    llm::text_generation::TextGeneration,
 };
 
+#[derive(Clone)]
+pub struct AppState {
+    pub config: Config,
+    pub text_generation: Option<TextGeneration>,
+}
+
 /// Creates and configures the Axum web server with various routes and Swagger UI.
 ///
 /// This function sets up all the necessary routes for the API and merges them
@@ -30,7 +37,7 @@ use crate::{
 /// # Returns
 ///
 /// An instance of `axum::Router` configured with all routes and the Swagger UI.
-pub fn server(config: Config) -> Router {
+pub fn server(config: Config, text_generation: Option<TextGeneration>) -> Router {
     let router = Router::new()
         .route("/", get(|| async { Redirect::permanent("/swagger-ui") }))
         .route("/", post(generate_handler))
@@ -39,7 +46,10 @@ pub fn server(config: Config) -> Router {
         .route("/info", get(get_info_handler))
         .route("/generate_stream", post(generate_stream_handler))
        .route("/model/:model/", post(generate_model_handler))
-        .with_state(config);
+        .with_state(AppState {
+            config: config.clone(),
+            text_generation,
+        });
 
     let swagger_ui = SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi());
 
@@ -58,7 +68,7 @@ mod tests {
     #[tokio::test]
     async fn test_root_redirects_to_swagger_ui() {
         let config = Config::default();
-        let app = server(config);
+        let app = server(config, None);
 
         let req = Request::builder()
             .method("GET")
@@ -76,7 +86,7 @@ mod tests {
     #[tokio::test]
     async fn test_swagger_ui_endpoint() {
         let config = Config::default();
-        let app = server(config);
+        let app = server(config, None);
 
         let req = Request::builder()
             .method("GET")
diff --git a/tests/server_test.rs b/tests/server_test.rs
index 490199d..3c64ef2 100644
--- a/tests/server_test.rs
+++ b/tests/server_test.rs
@@ -6,7 +6,7 @@ use chat_flame_backend::server::server;
 #[tokio::test]
 async fn test_generate_handler() {
     let config = Config::default();
-    let app = server(config);
+    let app = server(config, None);
     let server = TestServer::new(app).unwrap();
 
     let response = server
@@ -31,7 +31,7 @@
 #[tokio::test]
 async fn test_generate_text_handler() {
     let config = Config::default();
-    let app = server(config);
+    let app = server(config, None);
     let server = TestServer::new(app).unwrap();
 
     let response = server
@@ -56,7 +56,7 @@
 #[tokio::test]
 async fn test_generate_text_model_handler() {
     let config = Config::default();
-    let app = server(config);
+    let app = server(config, None);
     let server = TestServer::new(app).unwrap();
 
     let response = server
@@ -80,7 +80,7 @@
 #[tokio::test]
 async fn test_get_health_handler() {
     let config = Config::default();
-    let app = server(config);
+    let app = server(config, None);
     let server = TestServer::new(app).unwrap();
 
     let response = server.get("/health").await;
@@ -91,7 +91,7 @@
 #[tokio::test]
 async fn test_get_info_handler() {
     let config = Config::default();
-    let app = server(config);
+    let app = server(config, None);
     let server = TestServer::new(app).unwrap();
 
     let response = server.get("/info").await;
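For callers, the only visible change is the extra server() argument. A usage sketch for embedding the server with a kept-in-memory generator (assuming load_config returns a Result and Models implements Clone, consistent with how the PR clones Config elsewhere):

    let config = load_config("config.yml").unwrap();
    let generator = create_text_generation(config.model.clone(), &config.cache_dir).unwrap();
    let app = server(config, Some(generator));
    // `app` is an axum::Router; serve it with axum, or drive it via axum_test::TestServer.

Passing None, as the updated tests do, preserves the old load-per-request behavior.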