mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00
Add an API key option (`--api-key` / `API_KEY`) for auth and conditionally require authorization on non-info/health endpoints.
This commit is contained in:
parent
fb98ab273f
commit
b3e21ed42e
@ -349,6 +349,12 @@ Options:
|
|||||||
--cors-allow-origin <CORS_ALLOW_ORIGIN>
|
--cors-allow-origin <CORS_ALLOW_ORIGIN>
|
||||||
[env: CORS_ALLOW_ORIGIN=]
|
[env: CORS_ALLOW_ORIGIN=]
|
||||||
|
|
||||||
|
```
|
||||||
|
## API_KEY
|
||||||
|
```shell
|
||||||
|
--api-key <API_KEY>
|
||||||
|
[env: API_KEY=]
|
||||||
|
|
||||||
```
|
```
|
||||||
## WATERMARK_GAMMA
|
## WATERMARK_GAMMA
|
||||||
```shell
|
```shell
|
||||||
|
@ -418,6 +418,10 @@ struct Args {
|
|||||||
|
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
cors_allow_origin: Vec<String>,
|
cors_allow_origin: Vec<String>,
|
||||||
|
|
||||||
|
#[clap(long, env)]
|
||||||
|
api_key: Option<String>,
|
||||||
|
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
watermark_gamma: Option<f32>,
|
watermark_gamma: Option<f32>,
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
@ -1244,6 +1248,11 @@ fn spawn_webserver(
|
|||||||
router_args.push(origin);
|
router_args.push(origin);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// API key
|
||||||
|
if let Some(api_key) = args.api_key{
|
||||||
|
router_args.push("--api-key".to_string());
|
||||||
|
router_args.push(api_key);
|
||||||
|
}
|
||||||
// Ngrok
|
// Ngrok
|
||||||
if args.ngrok {
|
if args.ngrok {
|
||||||
router_args.push("--ngrok".to_string());
|
router_args.push("--ngrok".to_string());
|
||||||
|
@ -72,6 +72,8 @@ struct Args {
|
|||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
cors_allow_origin: Option<Vec<String>>,
|
cors_allow_origin: Option<Vec<String>>,
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
|
api_key: Option<String>,
|
||||||
|
#[clap(long, env)]
|
||||||
ngrok: bool,
|
ngrok: bool,
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
ngrok_authtoken: Option<String>,
|
ngrok_authtoken: Option<String>,
|
||||||
@ -113,6 +115,7 @@ async fn main() -> Result<(), RouterError> {
|
|||||||
otlp_endpoint,
|
otlp_endpoint,
|
||||||
otlp_service_name,
|
otlp_service_name,
|
||||||
cors_allow_origin,
|
cors_allow_origin,
|
||||||
|
api_key,
|
||||||
ngrok,
|
ngrok,
|
||||||
ngrok_authtoken,
|
ngrok_authtoken,
|
||||||
ngrok_edge,
|
ngrok_edge,
|
||||||
@ -379,6 +382,7 @@ async fn main() -> Result<(), RouterError> {
|
|||||||
validation_workers,
|
validation_workers,
|
||||||
addr,
|
addr,
|
||||||
cors_allow_origin,
|
cors_allow_origin,
|
||||||
|
api_key,
|
||||||
ngrok,
|
ngrok,
|
||||||
ngrok_authtoken,
|
ngrok_authtoken,
|
||||||
ngrok_edge,
|
ngrok_edge,
|
||||||
|
@ -31,6 +31,7 @@ use axum::response::{IntoResponse, Response};
|
|||||||
use axum::routing::{get, post};
|
use axum::routing::{get, post};
|
||||||
use axum::{http, Json, Router};
|
use axum::{http, Json, Router};
|
||||||
use axum_tracing_opentelemetry::middleware::OtelAxumLayer;
|
use axum_tracing_opentelemetry::middleware::OtelAxumLayer;
|
||||||
|
use http::header::AUTHORIZATION;
|
||||||
use futures::stream::StreamExt;
|
use futures::stream::StreamExt;
|
||||||
use futures::stream::{FuturesOrdered, FuturesUnordered};
|
use futures::stream::{FuturesOrdered, FuturesUnordered};
|
||||||
use futures::Stream;
|
use futures::Stream;
|
||||||
@ -1419,6 +1420,7 @@ pub async fn run(
|
|||||||
validation_workers: usize,
|
validation_workers: usize,
|
||||||
addr: SocketAddr,
|
addr: SocketAddr,
|
||||||
allow_origin: Option<AllowOrigin>,
|
allow_origin: Option<AllowOrigin>,
|
||||||
|
api_key: Option<String>,
|
||||||
ngrok: bool,
|
ngrok: bool,
|
||||||
_ngrok_authtoken: Option<String>,
|
_ngrok_authtoken: Option<String>,
|
||||||
_ngrok_edge: Option<String>,
|
_ngrok_edge: Option<String>,
|
||||||
@ -1793,16 +1795,39 @@ pub async fn run(
|
|||||||
let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc);
|
let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc);
|
||||||
|
|
||||||
// Define base and health routes
|
// Define base and health routes
|
||||||
let base_routes = Router::new()
|
let mut base_routes = Router::new()
|
||||||
.route("/", post(compat_generate))
|
.route("/", post(compat_generate))
|
||||||
.route("/", get(health))
|
|
||||||
.route("/info", get(get_model_info))
|
|
||||||
.route("/generate", post(generate))
|
.route("/generate", post(generate))
|
||||||
.route("/generate_stream", post(generate_stream))
|
.route("/generate_stream", post(generate_stream))
|
||||||
.route("/v1/chat/completions", post(chat_completions))
|
.route("/v1/chat/completions", post(chat_completions))
|
||||||
.route("/v1/completions", post(completions))
|
.route("/v1/completions", post(completions))
|
||||||
.route("/vertex", post(vertex_compatibility))
|
.route("/vertex", post(vertex_compatibility))
|
||||||
.route("/tokenize", post(tokenize))
|
.route("/tokenize", post(tokenize));
|
||||||
|
|
||||||
|
if let Some(api_key) = api_key {
|
||||||
|
let mut prefix = "Bearer ".to_string();
|
||||||
|
prefix.push_str(&api_key);
|
||||||
|
|
||||||
|
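// Leak the expected header value to obtain a `&'static str`: it is `Copy`, so the
// `move` closure and its `async move` body can capture it without cloning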
|
||||||
|
let api_key: &'static str = prefix.leak();
|
||||||
|
|
||||||
|
let auth = move |headers: HeaderMap,
|
||||||
|
request: axum::extract::Request,
|
||||||
|
next: axum::middleware::Next| async move {
|
||||||
|
match headers.get(AUTHORIZATION) {
|
||||||
|
Some(token) if token == api_key => {
|
||||||
|
let response = next.run(request).await;
|
||||||
|
Ok(response)
|
||||||
|
}
|
||||||
|
_ => Err(StatusCode::UNAUTHORIZED),
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
base_routes = base_routes.layer(axum::middleware::from_fn(auth))
|
||||||
|
}
|
||||||
|
let health_routes = Router::new()
|
||||||
|
.route("/", get(health))
|
||||||
|
.route("/info", get(get_model_info))
|
||||||
.route("/health", get(health))
|
.route("/health", get(health))
|
||||||
.route("/ping", get(health))
|
.route("/ping", get(health))
|
||||||
.route("/metrics", get(metrics));
|
.route("/metrics", get(metrics));
|
||||||
@ -1821,6 +1846,7 @@ pub async fn run(
|
|||||||
let mut app = Router::new()
|
let mut app = Router::new()
|
||||||
.merge(swagger_ui)
|
.merge(swagger_ui)
|
||||||
.merge(base_routes)
|
.merge(base_routes)
|
||||||
|
.merge(health_routes)
|
||||||
.merge(aws_sagemaker_route);
|
.merge(aws_sagemaker_route);
|
||||||
|
|
||||||
#[cfg(feature = "google")]
|
#[cfg(feature = "google")]
|
||||||
|
Loading…
Reference in New Issue
Block a user