From b3e21ed42e81049986a817b8f2c3e7d0470149fb Mon Sep 17 00:00:00 2001 From: Kevin Duffy Date: Fri, 28 Jun 2024 18:41:21 +0100 Subject: [PATCH 1/5] Add API_Key for Auth and conditionally add authorisation for non info/health endpoints. --- docs/source/basic_tutorials/launcher.md | 6 +++++ launcher/src/main.rs | 9 +++++++ router/src/main.rs | 4 +++ router/src/server.rs | 34 ++++++++++++++++++++++--- 4 files changed, 49 insertions(+), 4 deletions(-) diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md index 5e40146f..dbfaeab0 100644 --- a/docs/source/basic_tutorials/launcher.md +++ b/docs/source/basic_tutorials/launcher.md @@ -349,6 +349,12 @@ Options: --cors-allow-origin [env: CORS_ALLOW_ORIGIN=] +``` +## API_KEY +```shell + --api-key + [env: API_KEY=] + ``` ## WATERMARK_GAMMA ```shell diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 816fa5f3..60d3a38f 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -418,6 +418,10 @@ struct Args { #[clap(long, env)] cors_allow_origin: Vec, + + #[clap(long, env)] + api_key: Option, + #[clap(long, env)] watermark_gamma: Option, #[clap(long, env)] @@ -1244,6 +1248,11 @@ fn spawn_webserver( router_args.push(origin); } + // OpenTelemetry + if let Some(api_key) = args.api_key{ + router_args.push("--api-key".to_string()); + router_args.push(api_key); + } // Ngrok if args.ngrok { router_args.push("--ngrok".to_string()); diff --git a/router/src/main.rs b/router/src/main.rs index 1e8093d8..f5dc27f7 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -72,6 +72,8 @@ struct Args { #[clap(long, env)] cors_allow_origin: Option>, #[clap(long, env)] + api_key: Option, + #[clap(long, env)] ngrok: bool, #[clap(long, env)] ngrok_authtoken: Option, @@ -113,6 +115,7 @@ async fn main() -> Result<(), RouterError> { otlp_endpoint, otlp_service_name, cors_allow_origin, + api_key, ngrok, ngrok_authtoken, ngrok_edge, @@ -379,6 +382,7 @@ async fn main() -> Result<(), RouterError> { validation_workers, addr, cors_allow_origin, + api_key, ngrok, ngrok_authtoken, ngrok_edge, diff --git a/router/src/server.rs b/router/src/server.rs index 0cb08d4e..499b2abe 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -31,6 +31,7 @@ use axum::response::{IntoResponse, Response}; use axum::routing::{get, post}; use axum::{http, Json, Router}; use axum_tracing_opentelemetry::middleware::OtelAxumLayer; +use http::header::AUTHORIZATION; use futures::stream::StreamExt; use futures::stream::{FuturesOrdered, FuturesUnordered}; use futures::Stream; @@ -1419,6 +1420,7 @@ pub async fn run( validation_workers: usize, addr: SocketAddr, allow_origin: Option, + api_key: Option, ngrok: bool, _ngrok_authtoken: Option, _ngrok_edge: Option, @@ -1793,16 +1795,39 @@ pub async fn run( let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc); // Define base and health routes - let base_routes = Router::new() + let mut base_routes = Router::new() .route("/", post(compat_generate)) - .route("/", get(health)) - .route("/info", get(get_model_info)) .route("/generate", post(generate)) .route("/generate_stream", post(generate_stream)) .route("/v1/chat/completions", post(chat_completions)) .route("/v1/completions", post(completions)) .route("/vertex", post(vertex_compatibility)) - .route("/tokenize", post(tokenize)) + .route("/tokenize", post(tokenize)); + + if let Some(api_key) = api_key { + let mut prefix = "Bearer ".to_string(); + prefix.push_str(&api_key); + + // Leak to allow FnMut + let api_key: &'static str = prefix.leak(); + + let auth = move |headers: HeaderMap, + request: axum::extract::Request, + next: axum::middleware::Next| async move { + match headers.get(AUTHORIZATION) { + Some(token) if token == api_key => { + let response = next.run(request).await; + Ok(response) + } + _ => Err(StatusCode::UNAUTHORIZED), + } + }; + + base_routes = base_routes.layer(axum::middleware::from_fn(auth)) + } + let health_routes = Router::new() + .route("/", get(health)) + .route("/info", get(get_model_info)) .route("/health", get(health)) .route("/ping", get(health)) .route("/metrics", get(metrics)); @@ -1821,6 +1846,7 @@ pub async fn run( let mut app = Router::new() .merge(swagger_ui) .merge(base_routes) + .merge(health_routes) .merge(aws_sagemaker_route); #[cfg(feature = "google")] From 45da4460a39dacacd1d4e80f55224520e9a4b167 Mon Sep 17 00:00:00 2001 From: Kevin Duffy Date: Fri, 28 Jun 2024 18:46:07 +0100 Subject: [PATCH 2/5] change name to info routes --- router/src/server.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/router/src/server.rs b/router/src/server.rs index 499b2abe..8768c206 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1825,7 +1825,7 @@ pub async fn run( base_routes = base_routes.layer(axum::middleware::from_fn(auth)) } - let health_routes = Router::new() + let info_routes = Router::new() .route("/", get(health)) .route("/info", get(get_model_info)) .route("/health", get(health)) @@ -1846,7 +1846,7 @@ pub async fn run( let mut app = Router::new() .merge(swagger_ui) .merge(base_routes) - .merge(health_routes) + .merge(info_routes) .merge(aws_sagemaker_route); #[cfg(feature = "google")] From 19e63ffccca3b97b7b5dfd3b4ce4418fd32255b2 Mon Sep 17 00:00:00 2001 From: KevinDuffy94 Date: Wed, 24 Jul 2024 10:27:27 -0400 Subject: [PATCH 3/5] Fix comment --- launcher/src/main.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 60d3a38f..3d3bc4d6 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -1242,7 +1242,7 @@ fn spawn_webserver( router_args.push("--otlp-service-name".to_string()); router_args.push(otlp_service_name); - // CORS origins + // API Key for origin in args.cors_allow_origin.into_iter() { router_args.push("--cors-allow-origin".to_string()); router_args.push(origin); From 9f9997b5d441dd3a78e7525c9f2311b41076ca6d Mon Sep 17 00:00:00 2001 From: KevinDuffy94 Date: Wed, 24 Jul 2024 10:32:48 -0400 Subject: [PATCH 4/5] convert strings to lowercase for case insensitive comparison --- router/src/server.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/router/src/server.rs b/router/src/server.rs index 8768c206..516a2ef6 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1815,7 +1815,7 @@ pub async fn run( request: axum::extract::Request, next: axum::middleware::Next| async move { match headers.get(AUTHORIZATION) { - Some(token) if token == api_key => { + Some(token) if token.to_lowercase() == api_key.to_lowercase() => { let response = next.run(request).await; Ok(response) } From c5a982de822506b29452aecc82d7bbc2de67db41 Mon Sep 17 00:00:00 2001 From: KevinDuffy94 Date: Thu, 25 Jul 2024 09:16:24 -0400 Subject: [PATCH 5/5] convert header to string --- router/src/server.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/router/src/server.rs b/router/src/server.rs index 516a2ef6..fa2ba001 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1815,7 +1815,7 @@ pub async fn run( request: axum::extract::Request, next: axum::middleware::Next| async move { match headers.get(AUTHORIZATION) { - Some(token) if token.to_lowercase() == api_key.to_lowercase() => { + Some(token) if token.to_str().to_lowercase() == api_key.to_lowercase() => { let response = next.run(request).await; Ok(response) }