Add an API key option for authentication, and conditionally require authorisation on all endpoints except the info/health endpoints.

This commit is contained in:
Kevin Duffy 2024-06-28 18:41:21 +01:00
parent fb98ab273f
commit b3e21ed42e
4 changed files with 49 additions and 4 deletions

View File

@ -349,6 +349,12 @@ Options:
--cors-allow-origin <CORS_ALLOW_ORIGIN> --cors-allow-origin <CORS_ALLOW_ORIGIN>
[env: CORS_ALLOW_ORIGIN=] [env: CORS_ALLOW_ORIGIN=]
```
## API_KEY
```shell
--api-key <API_KEY>
[env: API_KEY=]
``` ```
## WATERMARK_GAMMA ## WATERMARK_GAMMA
```shell ```shell

View File

@ -418,6 +418,10 @@ struct Args {
#[clap(long, env)] #[clap(long, env)]
cors_allow_origin: Vec<String>, cors_allow_origin: Vec<String>,
#[clap(long, env)]
api_key: Option<String>,
#[clap(long, env)] #[clap(long, env)]
watermark_gamma: Option<f32>, watermark_gamma: Option<f32>,
#[clap(long, env)] #[clap(long, env)]
@ -1244,6 +1248,11 @@ fn spawn_webserver(
router_args.push(origin); router_args.push(origin);
} }
// API key
if let Some(api_key) = args.api_key{
router_args.push("--api-key".to_string());
router_args.push(api_key);
}
// Ngrok // Ngrok
if args.ngrok { if args.ngrok {
router_args.push("--ngrok".to_string()); router_args.push("--ngrok".to_string());

View File

@ -72,6 +72,8 @@ struct Args {
#[clap(long, env)] #[clap(long, env)]
cors_allow_origin: Option<Vec<String>>, cors_allow_origin: Option<Vec<String>>,
#[clap(long, env)] #[clap(long, env)]
api_key: Option<String>,
#[clap(long, env)]
ngrok: bool, ngrok: bool,
#[clap(long, env)] #[clap(long, env)]
ngrok_authtoken: Option<String>, ngrok_authtoken: Option<String>,
@ -113,6 +115,7 @@ async fn main() -> Result<(), RouterError> {
otlp_endpoint, otlp_endpoint,
otlp_service_name, otlp_service_name,
cors_allow_origin, cors_allow_origin,
api_key,
ngrok, ngrok,
ngrok_authtoken, ngrok_authtoken,
ngrok_edge, ngrok_edge,
@ -379,6 +382,7 @@ async fn main() -> Result<(), RouterError> {
validation_workers, validation_workers,
addr, addr,
cors_allow_origin, cors_allow_origin,
api_key,
ngrok, ngrok,
ngrok_authtoken, ngrok_authtoken,
ngrok_edge, ngrok_edge,

View File

@ -31,6 +31,7 @@ use axum::response::{IntoResponse, Response};
use axum::routing::{get, post}; use axum::routing::{get, post};
use axum::{http, Json, Router}; use axum::{http, Json, Router};
use axum_tracing_opentelemetry::middleware::OtelAxumLayer; use axum_tracing_opentelemetry::middleware::OtelAxumLayer;
use http::header::AUTHORIZATION;
use futures::stream::StreamExt; use futures::stream::StreamExt;
use futures::stream::{FuturesOrdered, FuturesUnordered}; use futures::stream::{FuturesOrdered, FuturesUnordered};
use futures::Stream; use futures::Stream;
@ -1419,6 +1420,7 @@ pub async fn run(
validation_workers: usize, validation_workers: usize,
addr: SocketAddr, addr: SocketAddr,
allow_origin: Option<AllowOrigin>, allow_origin: Option<AllowOrigin>,
api_key: Option<String>,
ngrok: bool, ngrok: bool,
_ngrok_authtoken: Option<String>, _ngrok_authtoken: Option<String>,
_ngrok_edge: Option<String>, _ngrok_edge: Option<String>,
@ -1793,16 +1795,39 @@ pub async fn run(
let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc); let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc);
// Define base and health routes // Define base and health routes
let base_routes = Router::new() let mut base_routes = Router::new()
.route("/", post(compat_generate)) .route("/", post(compat_generate))
.route("/", get(health))
.route("/info", get(get_model_info))
.route("/generate", post(generate)) .route("/generate", post(generate))
.route("/generate_stream", post(generate_stream)) .route("/generate_stream", post(generate_stream))
.route("/v1/chat/completions", post(chat_completions)) .route("/v1/chat/completions", post(chat_completions))
.route("/v1/completions", post(completions)) .route("/v1/completions", post(completions))
.route("/vertex", post(vertex_compatibility)) .route("/vertex", post(vertex_compatibility))
.route("/tokenize", post(tokenize)) .route("/tokenize", post(tokenize));
if let Some(api_key) = api_key {
let mut prefix = "Bearer ".to_string();
prefix.push_str(&api_key);
// Leak to allow FnMut
let api_key: &'static str = prefix.leak();
let auth = move |headers: HeaderMap,
request: axum::extract::Request,
next: axum::middleware::Next| async move {
match headers.get(AUTHORIZATION) {
Some(token) if token == api_key => {
let response = next.run(request).await;
Ok(response)
}
_ => Err(StatusCode::UNAUTHORIZED),
}
};
base_routes = base_routes.layer(axum::middleware::from_fn(auth))
}
let health_routes = Router::new()
.route("/", get(health))
.route("/info", get(get_model_info))
.route("/health", get(health)) .route("/health", get(health))
.route("/ping", get(health)) .route("/ping", get(health))
.route("/metrics", get(metrics)); .route("/metrics", get(metrics));
@ -1821,6 +1846,7 @@ pub async fn run(
let mut app = Router::new() let mut app = Router::new()
.merge(swagger_ui) .merge(swagger_ui)
.merge(base_routes) .merge(base_routes)
.merge(health_routes)
.merge(aws_sagemaker_route); .merge(aws_sagemaker_route);
#[cfg(feature = "google")] #[cfg(feature = "google")]