diff --git a/Cargo.toml b/Cargo.toml index 1047b2b..6a87f1a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,3 +40,4 @@ utoipa-swagger-ui = { version = "5", features = ["axum"] } axum-test = "14.0.0" tempfile = "3.8.1" tokio = { version = "1.0", features = ["full", "test-util"] } +tower = "0.4.13" \ No newline at end of file diff --git a/src/api/openapi.rs b/src/api/openapi.rs index 623ef07..ecbbf30 100644 --- a/src/api/openapi.rs +++ b/src/api/openapi.rs @@ -1,19 +1,58 @@ -use utoipa::OpenApi; - -use crate::api::model::ErrorResponse; - use super::model::{ CompatGenerateRequest, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, Info, StreamDetails, StreamResponse, Token, }; +use crate::api::model::ErrorResponse; +use utoipa::OpenApi; +/// Represents the API documentation for the text generation inference service. +/// +/// This struct uses `utoipa::OpenApi` to provide a centralized documentation of the API endpoints +/// and their associated request and response models. It is used to generate OpenAPI specification +/// for the service, which can be served as a Swagger UI or other OpenAPI-compatible documentation tools. #[derive(OpenApi)] #[openapi( - paths(super::routes::generate_text::generate_text_handler, super::routes::generate_stream::generate_stream_handler, super::routes::health::get_health_handler, super::routes::info::get_info_handler), + // List of API endpoints to be included in the documentation. + paths( + super::routes::generate_text::generate_text_handler, + super::routes::generate_stream::generate_stream_handler, + super::routes::health::get_health_handler, + super::routes::info::get_info_handler + ), + // Schema components for requests and responses used across the API. components( - schemas(CompatGenerateRequest, GenerateRequest, GenerateResponse, GenerateParameters, ErrorResponse, StreamResponse, - StreamDetails, Token, FinishReason, Info) + schemas( + CompatGenerateRequest, + GenerateRequest, + GenerateResponse, + GenerateParameters, + ErrorResponse, + StreamResponse, + StreamDetails, + Token, + FinishReason, + Info + ) ), - tags((name = "Text Generation Inference", description = "Text generation Inference API")) + // Metadata and description of the API tags. + tags( + (name = "Text Generation Inference", description = "Text generation Inference API") + ) )] pub struct ApiDoc; + +#[cfg(test)] +mod tests { + use super::*; + use utoipa::OpenApi; + + #[test] + fn api_doc_contains_all_endpoints() { + let api_doc = ApiDoc::openapi(); + let paths = api_doc.paths.paths; + assert!(paths.contains_key("/generate")); + assert!(paths.contains_key("/generate_stream")); + assert!(paths.contains_key("/health")); + assert!(paths.contains_key("/info")); + } +} diff --git a/src/server.rs b/src/server.rs index 2e2b109..c2a530a 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,4 +1,5 @@ use axum::{ + response::Redirect, routing::{get, post}, Router, }; @@ -14,8 +15,22 @@ use crate::{ config::Config, }; +/// Creates and configures the Axum web server with various routes and Swagger UI. +/// +/// This function sets up all the necessary routes for the API and merges them +/// with the Swagger UI for easy API documentation and testing. +/// +/// # Arguments +/// +/// * `config` - Configuration settings for the server. +/// +/// # Returns +/// +/// An instance of `axum::Router` configured with all routes and the Swagger UI. + pub fn server(config: Config) -> Router { let router = Router::new() + .route("/", get(|| async { Redirect::permanent("/swagger-ui") })) .route("/", post(generate_handler)) .route("/generate", post(generate_text_handler)) .route("/health", get(get_health_handler)) @@ -27,3 +42,46 @@ pub fn server(config: Config) -> Router { router.merge(swagger_ui) } + +#[cfg(test)] +mod tests { + use super::*; + use axum::{ + body::Body, + http::{Request, StatusCode}, + }; + use tower::ServiceExt; // for `oneshot` function + + #[tokio::test] + async fn test_root_redirects_to_swagger_ui() { + let config = Config::default(); + let app = server(config); + + let req = Request::builder() + .method("GET") + .uri("/") + .body(Body::empty()) + .unwrap(); + + let response = app.clone().oneshot(req).await.unwrap(); + + // Verify that the response is a redirect to /swagger-ui. + assert_eq!(response.status().as_u16(), 308); + assert_eq!(response.headers().get("location").unwrap(), "/swagger-ui"); + } + + #[tokio::test] + async fn test_swagger_ui_endpoint() { + let config = Config::default(); + let app = server(config); + + let req = Request::builder() + .method("GET") + .uri("/swagger-ui/index.html") + .body(Body::empty()) + .unwrap(); + + let response = app.clone().oneshot(req).await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + } +}