✨ cache model in memory
chriamue committed Jan 21, 2024
1 parent e90f976 commit ce02587
Showing 11 changed files with 105 additions and 69 deletions.
5 changes: 4 additions & 1 deletion config.yml
@@ -1,3 +1,6 @@
port: 8080
cache_dir: /tmp/models/
model: 7b-open-chat-3.5
model: 7b-open-chat-3.5

# keep default model in memory
keep_in_memory: true
20 changes: 14 additions & 6 deletions src/api/routes/generate.rs
@@ -7,7 +7,7 @@ use axum::{

use crate::{
api::model::{CompatGenerateRequest, ErrorResponse, GenerateRequest},
config::Config,
server::AppState,
};

use super::{generate_stream::generate_stream_handler, generate_text_handler};
@@ -48,12 +48,12 @@ use super::{generate_stream::generate_stream_handler, generate_text_handler};
)
)]
pub async fn generate_handler(
config: State<Config>,
app_state: State<AppState>,
Json(payload): Json<CompatGenerateRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
if payload.stream {
Ok(generate_stream_handler(
config,
app_state,
Json(GenerateRequest {
inputs: payload.inputs,
parameters: payload.parameters,
@@ -63,7 +63,7 @@ pub async fn generate_handler(
.into_response())
} else {
Ok(generate_text_handler(
config,
app_state,
Json(GenerateRequest {
inputs: payload.inputs,
parameters: payload.parameters,
@@ -91,9 +91,13 @@ mod tests {
#[ignore = "Will download model from HuggingFace"]
#[tokio::test]
async fn test_generate_handler_stream_enabled() {
let state = AppState {
config: Config::default(),
text_generation: None,
};
let app = Router::new()
.route("/", post(generate_handler))
.with_state(Config::default());
.with_state(state);

let response = app
.oneshot(
@@ -120,9 +124,13 @@ mod tests {
#[tokio::test]
#[ignore = "Will download model from HuggingFace"]
async fn test_generate_handler_stream_disabled() {
let state = AppState {
config: Config::default(),
text_generation: None,
};
let app = Router::new()
.route("/", post(generate_handler))
.with_state(Config::default());
.with_state(state);

let response = app
.oneshot(
12 changes: 9 additions & 3 deletions src/api/routes/generate_stream.rs
@@ -1,6 +1,7 @@
use crate::api::model::GenerateRequest;
use crate::llm::generate_parameter::GenerateParameter;
use crate::llm::text_generation::create_text_generation;
use crate::{api::model::GenerateRequest, config::Config};
use crate::server::AppState;
use axum::{
extract::State,
response::{sse::Event, IntoResponse, Sse},
@@ -38,7 +39,7 @@ use std::vec;
tag = "Text Generation Inference"
)]
pub async fn generate_stream_handler(
config: State<Config>,
app_state: State<AppState>,
Json(payload): Json<GenerateRequest>,
) -> impl IntoResponse {
debug!("Received request: {:?}", payload);
@@ -68,7 +69,12 @@ pub async fn generate_stream_handler(
None => vec!["<|endoftext|>".to_string(), "</s>".to_string()],
};

let mut generator = create_text_generation(config.model, &config.cache_dir).unwrap();
let config = app_state.config.clone();

let mut generator = match &app_state.text_generation {
Some(text_generation) => text_generation.clone(),
None => create_text_generation(config.model, &config.cache_dir).unwrap(),
};

let parameter = GenerateParameter {
temperature: temperature.unwrap_or_default(),
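
Both the streaming handler above and the plain text handler below now resolve their generator the same way: reuse the TextGeneration cached in AppState when it is present, otherwise build one from the configuration. A minimal sketch of that lookup pulled out into a helper, using the AppState, TextGeneration, and create_text_generation items shown in this diff; the error mapping is illustrative only, since the handlers above simply call .unwrap():

use axum::http::StatusCode;

use crate::{
    llm::text_generation::{create_text_generation, TextGeneration},
    server::AppState,
};

/// Hypothetical helper: prefer the generator cached in `AppState`,
/// otherwise create one on demand from the configured model and cache dir.
fn resolve_generator(app_state: &AppState) -> Result<TextGeneration, StatusCode> {
    match &app_state.text_generation {
        // Cheap clone: `TextGeneration` derives `Clone` and holds `Arc`-wrapped state.
        Some(generator) => Ok(generator.clone()),
        None => {
            let config = app_state.config.clone();
            create_text_generation(config.model, &config.cache_dir)
                .map_err(|_| StatusCode::FAILED_DEPENDENCY)
        }
    }
}
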
71 changes: 33 additions & 38 deletions src/api/routes/generate_text.rs
@@ -1,7 +1,7 @@
use crate::{
api::model::{ErrorResponse, GenerateRequest, GenerateResponse},
config::Config,
llm::{generate_parameter::GenerateParameter, text_generation::create_text_generation},
server::AppState,
};
use axum::{extract::State, http::StatusCode, response::IntoResponse, Json};

@@ -40,7 +40,7 @@ use axum::{extract::State, http::StatusCode, response::IntoResponse, Json};
tag = "Text Generation Inference"
)]
pub async fn generate_text_handler(
config: State<Config>,
app_state: State<AppState>,
Json(payload): Json<GenerateRequest>,
) -> impl IntoResponse {
let temperature = match &payload.parameters {
@@ -64,45 +64,40 @@ pub async fn generate_text_handler(
None => 50,
};

let generator = create_text_generation(config.model, &config.cache_dir);
match generator {
Ok(mut generator) => {
let parameter = GenerateParameter {
temperature: temperature.unwrap_or_default(),
top_p: top_p.unwrap_or_default(),
max_new_tokens: sample_len,
seed: 42,
repeat_penalty,
repeat_last_n,
};
let config = app_state.config.clone();

let generated_text = generator.run(&payload.inputs, parameter);
match generated_text {
Ok(generated_text) => match generated_text {
Some(text) => Ok(Json(GenerateResponse {
generated_text: text,
})),
None => Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: "Incomplete generation".to_string(),
error_type: None,
}),
)),
},
Err(_) => Err((
StatusCode::FAILED_DEPENDENCY,
Json(ErrorResponse {
error: "Request failed during generation".to_string(),
error_type: None,
}),
)),
}
}
let mut generator = match &app_state.text_generation {
Some(text_generation) => text_generation.clone(),
None => create_text_generation(config.model, &config.cache_dir).unwrap(),
};

let parameter = GenerateParameter {
temperature: temperature.unwrap_or_default(),
top_p: top_p.unwrap_or_default(),
max_new_tokens: sample_len,
seed: 42,
repeat_penalty,
repeat_last_n,
};

let generated_text = generator.run(&payload.inputs, parameter);
match generated_text {
Ok(generated_text) => match generated_text {
Some(text) => Ok(Json(GenerateResponse {
generated_text: text,
})),
None => Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: "Incomplete generation".to_string(),
error_type: None,
}),
)),
},
Err(_) => Err((
StatusCode::TOO_MANY_REQUESTS,
StatusCode::FAILED_DEPENDENCY,
Json(ErrorResponse {
error: "Model is overloaded".to_string(),
error: "Request failed during generation".to_string(),
error_type: None,
}),
)),
11 changes: 8 additions & 3 deletions src/api/routes/info.rs
@@ -1,6 +1,6 @@
//! This module contains the endpoint for retrieving model information.

use crate::{api::model::Info, config::Config};
use crate::{api::model::Info, server::AppState};
use axum::{extract::State, http::StatusCode, Json};

/// Endpoint to get model information.
@@ -15,7 +15,8 @@ use axum::{extract::State, http::StatusCode, Json};
),
tag = "Text Generation Inference"
)]
pub async fn get_info_handler(config: State<Config>) -> Result<Json<Info>, StatusCode> {
pub async fn get_info_handler(app_state: State<AppState>) -> Result<Json<Info>, StatusCode> {
let config = &app_state.config;
let version = env!("CARGO_PKG_VERSION");
let model_info = Info {
docker_label: None,
@@ -51,9 +52,13 @@ mod tests {
port: 8080,
cache_dir: None,
model: Models::default(),
keep_in_memory: None,
};

let state = State(test_config.clone());
let state = State(AppState {
config: test_config.clone(),
text_generation: None,
});
let response = get_info_handler(state).await.unwrap();
let info = response.0;
assert_eq!(info.max_batch_total_tokens, 2048);
12 changes: 6 additions & 6 deletions src/api/routes/model.rs
@@ -7,8 +7,8 @@ use axum::{

use crate::{
api::model::{CompatGenerateRequest, ErrorResponse, GenerateRequest},
config::Config,
llm::models::Models,
server::AppState,
};

use super::{generate_stream::generate_stream_handler, generate_text_handler};
@@ -54,15 +54,15 @@ use super::{generate_stream::generate_stream_handler, generate_text_handler};

pub async fn generate_model_handler(
Path(model): Path<Models>,
config: State<Config>,
app_state: State<AppState>,
Json(payload): Json<CompatGenerateRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let mut config = config.clone();
config.model = model;
let mut app_state = app_state.clone();
app_state.config.model = model;

if payload.stream {
Ok(generate_stream_handler(
config,
app_state,
Json(GenerateRequest {
inputs: payload.inputs,
parameters: payload.parameters,
@@ -72,7 +72,7 @@ pub async fn generate_model_handler(
.into_response())
} else {
Ok(generate_text_handler(
config,
app_state,
Json(GenerateRequest {
inputs: payload.inputs,
parameters: payload.parameters,
4 changes: 4 additions & 0 deletions src/config.rs
@@ -18,6 +18,9 @@ pub struct Config {

/// Model to be used by the server.
pub model: Models,

/// Whether to keep the default model in memory.
pub keep_in_memory: Option<bool>,
}

/// Loads the application configuration from a YAML file.
@@ -67,5 +70,6 @@ mod tests {
assert_eq!(config.port, 8080);
assert_eq!(config.cache_dir, Some(PathBuf::from("/tmp")));
assert_eq!(config.model, Models::OpenChat35);
assert_eq!(config.keep_in_memory, None);
}
}
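
keep_in_memory is an Option<bool>, so configuration files written before this commit, which have no such key, still deserialize and simply yield None, as the updated test asserts. A small self-contained sketch of that behaviour, assuming serde_yaml and a simplified stand-in for the real Config struct (which also carries cache_dir and the Models enum):

use serde::Deserialize;

// Simplified stand-in for the repository's `Config`.
#[derive(Deserialize)]
struct Config {
    port: u16,
    model: String,
    keep_in_memory: Option<bool>,
}

fn main() {
    let with_flag = "port: 8080\nmodel: 7b-open-chat-3.5\nkeep_in_memory: true\n";
    let without_flag = "port: 8080\nmodel: 7b-open-chat-3.5\n";

    let a: Config = serde_yaml::from_str(with_flag).unwrap();
    let b: Config = serde_yaml::from_str(without_flag).unwrap();

    assert_eq!(a.keep_in_memory, Some(true)); // flag present, as in the new config.yml
    assert_eq!(b.keep_in_memory, None);       // older config files keep working
}
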
1 change: 1 addition & 0 deletions src/llm/text_generation.rs
@@ -26,6 +26,7 @@ use super::{
Model,
};

#[derive(Clone)]
pub struct TextGeneration {
model: Arc<Mutex<Model>>,
tokenizer: Arc<Mutex<TokenOutputStream>>,
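
The Clone derive is what lets each request take its own handle on the cached generator. Because the fields are Arc<Mutex<...>>, cloning only bumps reference counts; every handle shares the same underlying model and tokenizer instead of reloading them. A self-contained illustration of that sharing, using a hypothetical stand-in type:

use std::sync::{Arc, Mutex};

// Hypothetical stand-in for a struct whose fields are Arc-wrapped,
// like `TextGeneration` after this commit.
#[derive(Clone)]
struct Cached {
    state: Arc<Mutex<u32>>,
}

fn main() {
    let original = Cached { state: Arc::new(Mutex::new(0)) };
    let handle = original.clone();                  // clones the Arc, not the data

    *handle.state.lock().unwrap() += 1;             // mutate through one handle...
    assert_eq!(*original.state.lock().unwrap(), 1); // ...and see it through the other

    assert_eq!(Arc::strong_count(&original.state), 2);
}
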
10 changes: 7 additions & 3 deletions src/main.rs
@@ -1,7 +1,7 @@
use chat_flame_backend::{
config::{load_config, Config},
llm::{
generate_parameter::GenerateParameter, loader::create_model, models::Models,
generate_parameter::GenerateParameter, models::Models,
text_generation::create_text_generation,
},
server::server,
@@ -72,11 +72,15 @@ async fn generate_text(
async fn start_server(model: Models, config: Config) {
info!("Starting server");
info!("preload model");
let _ = create_model(model, &config.cache_dir);
let text_generation = create_text_generation(model, &config.cache_dir).unwrap();

info!("Running on port: {}", config.port);
let addr = SocketAddr::from(([0, 0, 0, 0], config.port));
let app = server(config);

let app = match &config.keep_in_memory {
Some(true) => server(config, Some(text_generation)),
_ => server(config, None),
};

info!("Server running at http://{}", addr);

18 changes: 14 additions & 4 deletions src/server.rs
@@ -16,8 +16,15 @@ use crate::{
routes::{get_health_handler, get_info_handler},
},
config::Config,
llm::text_generation::TextGeneration,
};

#[derive(Clone)]
pub struct AppState {
pub config: Config,
pub text_generation: Option<TextGeneration>,
}

/// Creates and configures the Axum web server with various routes and Swagger UI.
///
/// This function sets up all the necessary routes for the API and merges them
@@ -30,7 +37,7 @@ use crate::{
/// # Returns
///
/// An instance of `axum::Router` configured with all routes and the Swagger UI.
pub fn server(config: Config) -> Router {
pub fn server(config: Config, text_generation: Option<TextGeneration>) -> Router {
let router = Router::new()
.route("/", get(|| async { Redirect::permanent("/swagger-ui") }))
.route("/", post(generate_handler))
@@ -39,7 +46,10 @@ pub fn server(config: Config) -> Router {
.route("/info", get(get_info_handler))
.route("/generate_stream", post(generate_stream_handler))
.route("/model/:model/", post(generate_model_handler))
.with_state(config);
.with_state(AppState {
config: config.clone(),
text_generation,
});

let swagger_ui = SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi());

@@ -58,7 +68,7 @@ mod tests {
#[tokio::test]
async fn test_root_redirects_to_swagger_ui() {
let config = Config::default();
let app = server(config);
let app = server(config, None);

let req = Request::builder()
.method("GET")
@@ -76,7 +86,7 @@
#[tokio::test]
async fn test_swagger_ui_endpoint() {
let config = Config::default();
let app = server(config);
let app = server(config, None);

let req = Request::builder()
.method("GET")
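
Taken together, the new wiring is: load the config, pre-load the default model only when keep_in_memory is set, and hand both to server(). A hedged usage sketch built from the crate paths shown in main.rs above; the config file name, the error handling, and the axum 0.7-style serve call are illustrative assumptions, not the repository's exact startup code:

use std::net::SocketAddr;

use chat_flame_backend::{
    config::load_config, llm::text_generation::create_text_generation, server::server,
};

#[tokio::main]
async fn main() {
    // Assumed call shape; `load_config` reads a YAML file like the one shown above.
    let config = load_config("config.yml").expect("config should parse");

    // Keep the default model in memory only when the config asks for it.
    let text_generation = match config.keep_in_memory {
        Some(true) => Some(
            create_text_generation(config.model.clone(), &config.cache_dir)
                .expect("default model should load"),
        ),
        _ => None,
    };

    let addr = SocketAddr::from(([0, 0, 0, 0], config.port));
    let app = server(config, text_generation);

    // Serving is unchanged by this commit; this assumes an axum 0.7-style setup.
    let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
    axum::serve(listener, app).await.unwrap();
}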