Skip to content

Commit

Permalink
✨ implements non streaming generate endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
chriamue committed Dec 26, 2023
1 parent 3a54e9c commit 0f26964
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 35 deletions.
2 changes: 2 additions & 0 deletions src/api/openapi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use utoipa::OpenApi;
#[openapi(
// List of API endpoints to be included in the documentation.
paths(
super::routes::generate::generate_handler,
super::routes::generate_text::generate_text_handler,
super::routes::generate_stream::generate_stream_handler,
super::routes::health::get_health_handler,
Expand Down Expand Up @@ -50,6 +51,7 @@ mod tests {
fn api_doc_contains_all_endpoints() {
let api_doc = ApiDoc::openapi();
let paths = api_doc.paths.paths;
assert!(paths.contains_key("/"));
assert!(paths.contains_key("/generate"));
assert!(paths.contains_key("/generate_stream"));
assert!(paths.contains_key("/health"));
Expand Down
78 changes: 50 additions & 28 deletions src/api/routes/generate.rs
Original file line number Diff line number Diff line change
@@ -1,56 +1,77 @@
use axum::{extract::State, http::StatusCode, response::IntoResponse, Json};
use axum::{
extract::State,
http::StatusCode,
response::{IntoResponse, Response},
Json,
};

use crate::{
api::model::{CompatGenerateRequest, ErrorResponse, GenerateRequest},
config::Config,
};

use super::generate_stream::generate_stream_handler;
use super::{generate_stream::generate_stream_handler, generate_text_handler};

/// Handler for generating text tokens.
///
/// This endpoint accepts a `CompatGenerateRequest` and returns a stream of generated text.
/// It requires the `stream` field in the request to be true. If `stream` is false,
/// the handler will return a `StatusCode::NOT_IMPLEMENTED` error.
/// This endpoint accepts a `CompatGenerateRequest` and returns a stream of generated text
/// or a single text response based on the `stream` field in the request. If `stream` is true,
/// it returns a stream of `StreamResponse`. If `stream` is false, it returns `GenerateResponse`.
///
/// # Arguments
/// * `config` - State containing the application configuration.
/// * `payload` - JSON payload containing the input text and optional parameters.
///
/// # Responses
/// * `200 OK` - Successful generation of text, returns a stream of `StreamResponse`.
/// * `501 Not Implemented` - Returned if `stream` field in request is false.
/// * `200 OK` - Successful generation of text.
/// * `501 Not Implemented` - Returned if streaming is not implemented.
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/",
request_body = CompatGenerateRequest,
responses(
(status = 200, description = "Generated Text", body = StreamResponse),
(status = 501, description = "Streaming not enabled", body = ErrorResponse),
),
tag = "Text Generation Inference"
(status = 200, description = "Generated Text",
content(
("application/json" = GenerateResponse),
("text/event-stream" = StreamResponse),
)
),
(status = 424, description = "Generation Error", body = ErrorResponse,
example = json!({"error": "Request failed during generation"})),
(status = 429, description = "Model is overloaded", body = ErrorResponse,
example = json!({"error": "Model is overloaded"})),
(status = 422, description = "Input validation error", body = ErrorResponse,
example = json!({"error": "Input validation error"})),
(status = 500, description = "Incomplete generation", body = ErrorResponse,
example = json!({"error": "Incomplete generation"})),
)
)]
pub async fn generate_handler(
config: State<Config>,
Json(payload): Json<CompatGenerateRequest>,
) -> impl IntoResponse {
if !payload.stream {
return Err((
StatusCode::NOT_IMPLEMENTED,
Json(ErrorResponse {
error: "Use /generate endpoint if not streaming".to_string(),
error_type: None,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
if payload.stream {
Ok(generate_stream_handler(
config,
Json(GenerateRequest {
inputs: payload.inputs,
parameters: payload.parameters,
}),
)
.await
.into_response())
} else {
Ok(generate_text_handler(
config,
Json(GenerateRequest {
inputs: payload.inputs,
parameters: payload.parameters,
}),
));
)
.await
.into_response())
}
Ok(generate_stream_handler(
config,
Json(GenerateRequest {
inputs: payload.inputs,
parameters: payload.parameters,
}),
)
.await)
}

#[cfg(test)]
Expand Down Expand Up @@ -97,6 +118,7 @@ mod tests {

/// Test the generate_handler function for streaming disabled.
#[tokio::test]
#[ignore = "Will download model from HuggingFace"]
async fn test_generate_handler_stream_disabled() {
let app = Router::new()
.route("/", post(generate_handler))
Expand All @@ -120,6 +142,6 @@ mod tests {
.await
.unwrap();

assert_eq!(response.status(), StatusCode::NOT_IMPLEMENTED);
assert_eq!(response.status(), StatusCode::OK);
}
}
19 changes: 18 additions & 1 deletion src/api/routes/generate_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,24 @@ use futures::stream::StreamExt;
use log::debug;
use std::vec;

/// Generate tokens
/// Asynchronous handler for generating text through a streaming API.
///
/// This function handles POST requests to the `/generate_stream` endpoint. It takes a JSON payload
/// representing a `GenerateRequest` and uses the configuration and parameters specified to
/// generate text using a streaming approach. The response is a stream of Server-Sent Events (SSE),
/// allowing clients to receive generated text in real-time as it is produced.
///
/// # Parameters
/// - `config`: Application state holding the global configuration.
/// - `Json(payload)`: JSON payload containing the input text and generation parameters.
///
/// # Responses
/// - `200 OK`: Stream of generated text as `StreamResponse` events.
/// - Error responses: Descriptive error messages if any issues occur.
///
/// # Usage
/// This endpoint is suitable for scenarios where real-time text generation is required,
/// such as interactive chatbots or live content creation tools.
#[utoipa::path(
post,
path = "/generate_stream",
Expand Down
22 changes: 21 additions & 1 deletion src/api/routes/generate_text.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,27 @@ use crate::{
};
use axum::{extract::State, http::StatusCode, response::IntoResponse, Json};

/// Generate tokens
/// Asynchronous handler for generating text.
///
/// This function handles POST requests to the `/generate` endpoint. It takes a JSON payload
/// representing a `GenerateRequest` and uses the configuration and parameters specified to
/// generate text. The generated text is returned in a `GenerateResponse` if successful.
///
/// # Parameters
/// - `config`: Application state holding the global configuration.
/// - `Json(payload)`: JSON payload containing the input text and generation parameters.
///
/// # Responses
/// - `200 OK`: Successful text generation with `GenerateResponse`.
/// - `422 Unprocessable Entity`: Input validation error with `ErrorResponse`.
/// - `424 Failed Dependency`: Generation error with `ErrorResponse`.
/// - `429 Too Many Requests`: Model is overloaded with `ErrorResponse`.
/// - `500 Internal Server Error`: Incomplete generation with `ErrorResponse`.
///
/// # Usage
/// This endpoint is suitable for generating text based on given prompts and parameters.
/// It can be used in scenarios where batch text generation is required, such as content
/// creation, language modeling, or any application needing on-demand text generation.
#[utoipa::path(
post,
path = "/generate",
Expand Down
23 changes: 18 additions & 5 deletions src/api/routes/mod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,21 @@
pub mod generate;
pub mod generate_stream;
pub mod generate_text;
pub mod health;
pub mod info;
/// Module containing all route handlers.
///
/// This module organizes the different API endpoints and their associated handlers.
/// Each route corresponds to a specific functionality of the text generation inference API.
///
/// # Modules
/// * `generate` - Handles requests for token generation with streaming capability.
/// * `generate_stream` - Handles streaming requests for text generation.
/// * `generate_text` - Handles requests for generating text without streaming.
/// * `health` - Provides a health check endpoint.
/// * `info` - Provides information about the text generation inference service.
pub mod generate; // Module for handling token generation with streaming.
pub mod generate_stream; // Module for handling streaming text generation requests.
pub mod generate_text; // Module for handling text generation requests.
pub mod health; // Module for the health check endpoint.
pub mod info; // Module for the service information endpoint.

// Public exports of route handlers for ease of access.
pub use generate::generate_handler;
pub use generate_stream::generate_stream_handler;
pub use generate_text::generate_text_handler;
Expand Down

0 comments on commit 0f26964

Please sign in to comment.