Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add api_key for request authorization #211

Merged
merged 2 commits into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -235,14 +235,21 @@ Options:
[env: PAYLOAD_LIMIT=]
[default: 2000000]

--api-key <API_KEY>
Set an api key for request authorization.

By default the server responds to every request. With an api key set, the requests must have the Authorization header set with the api key as Bearer token.

[env: API_KEY=]

--json-output
Outputs the logs in JSON format (useful for telemetry)

[env: JSON_OUTPUT=]

--otlp-endpoint <OTLP_ENDPOINT>
The grpc endpoint for opentelemetry. Telemetry is sent to this endpoint as OTLP over gRPC.
e.g. `http://localhost:4317`
The grpc endpoint for opentelemetry. Telemetry is sent to this endpoint as OTLP over gRPC. e.g. `http://localhost:4317`

[env: OTLP_ENDPOINT=]

--cors-allow-origin <CORS_ALLOW_ORIGIN>
Expand Down
17 changes: 17 additions & 0 deletions docs/source/en/cli_arguments.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,29 @@ Options:

[env: HUGGINGFACE_HUB_CACHE=/data]

--payload-limit <PAYLOAD_LIMIT>
Payload size limit in bytes

Default is 2MB

[env: PAYLOAD_LIMIT=]
[default: 2000000]

--api-key <API_KEY>
Set an api key for request authorization.

By default the server responds to every request. With an api key set, the requests must have the Authorization header set with the api key as Bearer token.

[env: API_KEY=]

--json-output
Outputs the logs in JSON format (useful for telemetry)

[env: JSON_OUTPUT=]

--otlp-endpoint <OTLP_ENDPOINT>
The grpc endpoint for opentelemetry. Telemetry is sent to this endpoint as OTLP over gRPC. e.g. `http://localhost:4317`

[env: OTLP_ENDPOINT=]

--cors-allow-origin <CORS_ALLOW_ORIGIN>
Expand Down
50 changes: 40 additions & 10 deletions router/src/grpc/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1334,6 +1334,7 @@ pub async fn run(
info: Info,
addr: SocketAddr,
prom_builder: PrometheusBuilder,
api_key: Option<String>,
) -> Result<(), anyhow::Error> {
prom_builder.install()?;
tracing::info!("Serving Prometheus metrics: 0.0.0.0:9000");
Expand Down Expand Up @@ -1431,17 +1432,46 @@ pub async fn run(
let service = TextEmbeddingsService::new(infer, info);

// Create gRPC server
let server = if let Some(api_key) = api_key {
let mut prefix = "Bearer ".to_string();
prefix.push_str(&api_key);

// Leak to allow FnMut
let api_key: &'static str = prefix.leak();

let auth = move |req: Request<()>| -> Result<Request<()>, Status> {
match req.metadata().get("authorization") {
Some(t) if t == api_key => Ok(req),
_ => Err(Status::unauthenticated("No valid auth token")),
}
};

Server::builder()
.add_service(health_service)
.add_service(reflection_service)
.add_service(grpc::InfoServer::with_interceptor(service.clone(), auth))
.add_service(grpc::TokenizeServer::with_interceptor(
service.clone(),
auth,
))
.add_service(grpc::EmbedServer::with_interceptor(service.clone(), auth))
.add_service(grpc::PredictServer::with_interceptor(service.clone(), auth))
.add_service(grpc::RerankServer::with_interceptor(service, auth))
.serve_with_shutdown(addr, shutdown::shutdown_signal())
} else {
Server::builder()
.add_service(health_service)
.add_service(reflection_service)
.add_service(grpc::InfoServer::new(service.clone()))
.add_service(grpc::TokenizeServer::new(service.clone()))
.add_service(grpc::EmbedServer::new(service.clone()))
.add_service(grpc::PredictServer::new(service.clone()))
.add_service(grpc::RerankServer::new(service))
.serve_with_shutdown(addr, shutdown::shutdown_signal())
};

tracing::info!("Starting gRPC server: {}", &addr);
Server::builder()
.add_service(health_service)
.add_service(reflection_service)
.add_service(grpc::InfoServer::new(service.clone()))
.add_service(grpc::TokenizeServer::new(service.clone()))
.add_service(grpc::EmbedServer::new(service.clone()))
.add_service(grpc::PredictServer::new(service.clone()))
.add_service(grpc::RerankServer::new(service))
.serve_with_shutdown(addr, shutdown::shutdown_signal())
.await?;
server.await?;

Ok(())
}
Expand Down
26 changes: 25 additions & 1 deletion router/src/http/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use axum::routing::{get, post};
use axum::{http, Json, Router};
use axum_tracing_opentelemetry::middleware::OtelAxumLayer;
use futures::future::join_all;
use http::header::AUTHORIZATION;
use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
use std::net::SocketAddr;
use std::time::{Duration, Instant};
Expand Down Expand Up @@ -1263,6 +1264,7 @@ pub async fn run(
addr: SocketAddr,
prom_builder: PrometheusBuilder,
payload_limit: usize,
api_key: Option<String>,
cors_allow_origin: Option<Vec<String>>,
) -> Result<(), anyhow::Error> {
// OpenAPI documentation
Expand Down Expand Up @@ -1434,13 +1436,35 @@ pub async fn run(
}
}

let app = app
app = app
.layer(Extension(infer))
.layer(Extension(info))
.layer(Extension(prom_handle.clone()))
.layer(OtelAxumLayer::default())
.layer(cors_layer);

if let Some(api_key) = api_key {
let mut prefix = "Bearer ".to_string();
prefix.push_str(&api_key);

// Leak to allow FnMut
let api_key: &'static str = prefix.leak();

let auth = move |headers: HeaderMap,
request: axum::extract::Request,
next: axum::middleware::Next| async move {
match headers.get(AUTHORIZATION) {
Some(token) if token == api_key => {
let response = next.run(request).await;
Ok(response)
}
_ => Err(StatusCode::UNAUTHORIZED),
}
};

app = app.layer(axum::middleware::from_fn(auth));
}

// Run server
let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();

Expand Down
10 changes: 7 additions & 3 deletions router/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ pub async fn run(
uds_path: Option<String>,
huggingface_hub_cache: Option<String>,
payload_limit: usize,
api_key: Option<String>,
otlp_endpoint: Option<String>,
cors_allow_origin: Option<Vec<String>>,
) -> Result<()> {
Expand Down Expand Up @@ -275,6 +276,7 @@ pub async fn run(
addr,
prom_builder,
payload_limit,
api_key,
cors_allow_origin,
)
.await
Expand All @@ -285,10 +287,12 @@ pub async fn run(

#[cfg(feature = "grpc")]
{
// cors_allow_origin is not used for gRPC servers
// cors_allow_origin and payload_limit are not used for gRPC servers
let _ = cors_allow_origin;
let server =
tokio::spawn(async move { grpc::server::run(infer, info, addr, prom_builder).await });
let _ = payload_limit;
let server = tokio::spawn(async move {
grpc::server::run(infer, info, addr, prom_builder, api_key).await
});
tracing::info!("Ready");
server.await??;
}
Expand Down
7 changes: 7 additions & 0 deletions router/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,12 @@ struct Args {
#[clap(default_value = "2000000", long, env)]
payload_limit: usize,

/// Set an api key for request authorization.
///
/// By default the server responds to every request. With an api key set, the requests must have the Authorization header set with the api key as Bearer token.
#[clap(long, env)]
api_key: Option<String>,

/// Outputs the logs in JSON format (useful for telemetry)
#[clap(long, env)]
json_output: bool,
Expand Down Expand Up @@ -143,6 +149,7 @@ async fn main() -> Result<()> {
Some(args.uds_path),
args.huggingface_hub_cache,
args.payload_limit,
args.api_key,
args.otlp_endpoint,
args.cors_allow_origin,
)
Expand Down
Loading