diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md index 77f88490..ce98876f 100644 --- a/docs/source/basic_tutorials/launcher.md +++ b/docs/source/basic_tutorials/launcher.md @@ -349,6 +349,12 @@ Options: --cors-allow-origin [env: CORS_ALLOW_ORIGIN=] +``` +## API_KEY +```shell + --api-key + [env: API_KEY=] + ``` ## WATERMARK_GAMMA ```shell diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 228b0e79..9fbc2d84 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -422,6 +422,10 @@ struct Args { #[clap(long, env)] cors_allow_origin: Vec, + + #[clap(long, env)] + api_key: Option, + #[clap(long, env)] watermark_gamma: Option, #[clap(long, env)] @@ -1265,12 +1269,17 @@ fn spawn_webserver( router_args.push("--otlp-service-name".to_string()); router_args.push(otlp_service_name); - // CORS origins + // API Key for origin in args.cors_allow_origin.into_iter() { router_args.push("--cors-allow-origin".to_string()); router_args.push(origin); } + // OpenTelemetry + if let Some(api_key) = args.api_key{ + router_args.push("--api-key".to_string()); + router_args.push(api_key); + } // Ngrok if args.ngrok { router_args.push("--ngrok".to_string()); diff --git a/router/src/main.rs b/router/src/main.rs index bfc77913..36879aa4 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -77,6 +77,8 @@ struct Args { #[clap(long, env)] cors_allow_origin: Option>, #[clap(long, env)] + api_key: Option, + #[clap(long, env)] ngrok: bool, #[clap(long, env)] ngrok_authtoken: Option, @@ -127,6 +129,7 @@ async fn main() -> Result<(), RouterError> { otlp_endpoint, otlp_service_name, cors_allow_origin, + api_key, ngrok, ngrok_authtoken, ngrok_edge, @@ -446,6 +449,7 @@ async fn main() -> Result<(), RouterError> { validation_workers, addr, cors_allow_origin, + api_key, ngrok, ngrok_authtoken, ngrok_edge, diff --git a/router/src/server.rs b/router/src/server.rs index c56c39a3..3f6dd07f 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -33,6 +33,7 @@ use axum::response::{IntoResponse, Response}; use axum::routing::{get, post}; use axum::{http, Json, Router}; use axum_tracing_opentelemetry::middleware::OtelAxumLayer; +use http::header::AUTHORIZATION; use futures::stream::StreamExt; use futures::stream::{FuturesOrdered, FuturesUnordered}; use futures::Stream; @@ -1417,6 +1418,7 @@ pub async fn run( validation_workers: usize, addr: SocketAddr, allow_origin: Option, + api_key: Option, ngrok: bool, _ngrok_authtoken: Option, _ngrok_edge: Option, @@ -1810,16 +1812,39 @@ pub async fn run( let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc); // Define base and health routes - let base_routes = Router::new() + let mut base_routes = Router::new() .route("/", post(compat_generate)) - .route("/", get(health)) - .route("/info", get(get_model_info)) .route("/generate", post(generate)) .route("/generate_stream", post(generate_stream)) .route("/v1/chat/completions", post(chat_completions)) .route("/v1/completions", post(completions)) .route("/vertex", post(vertex_compatibility)) - .route("/tokenize", post(tokenize)) + .route("/tokenize", post(tokenize)); + + if let Some(api_key) = api_key { + let mut prefix = "Bearer ".to_string(); + prefix.push_str(&api_key); + + // Leak to allow FnMut + let api_key: &'static str = prefix.leak(); + + let auth = move |headers: HeaderMap, + request: axum::extract::Request, + next: axum::middleware::Next| async move { + match headers.get(AUTHORIZATION) { + Some(token) if token.to_str().to_lowercase() == api_key.to_lowercase() => { + let response = next.run(request).await; + Ok(response) + } + _ => Err(StatusCode::UNAUTHORIZED), + } + }; + + base_routes = base_routes.layer(axum::middleware::from_fn(auth)) + } + let info_routes = Router::new() + .route("/", get(health)) + .route("/info", get(get_model_info)) .route("/health", get(health)) .route("/ping", get(health)) .route("/metrics", get(metrics)); @@ -1838,6 +1863,7 @@ pub async fn run( let mut app = Router::new() .merge(swagger_ui) .merge(base_routes) + .merge(info_routes) .merge(aws_sagemaker_route); #[cfg(feature = "google")]