diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 9fbc2d84..819cc242 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -1276,7 +1276,7 @@ fn spawn_webserver( } // OpenTelemetry - if let Some(api_key) = args.api_key{ + if let Some(api_key) = args.api_key { router_args.push("--api-key".to_string()); router_args.push(api_key); } diff --git a/router/src/server.rs b/router/src/server.rs index 3f6dd07f..89e6402b 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -33,11 +33,11 @@ use axum::response::{IntoResponse, Response}; use axum::routing::{get, post}; use axum::{http, Json, Router}; use axum_tracing_opentelemetry::middleware::OtelAxumLayer; -use http::header::AUTHORIZATION; use futures::stream::StreamExt; use futures::stream::{FuturesOrdered, FuturesUnordered}; use futures::Stream; use futures::TryStreamExt; +use http::header::AUTHORIZATION; use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; use serde_json::Value; use std::convert::Infallible; @@ -1832,10 +1832,13 @@ pub async fn run( request: axum::extract::Request, next: axum::middleware::Next| async move { match headers.get(AUTHORIZATION) { - Some(token) if token.to_str().to_lowercase() == api_key.to_lowercase() => { - let response = next.run(request).await; - Ok(response) - } + Some(token) => match token.to_str() { + Ok(token_str) if token_str.to_lowercase() == api_key.to_lowercase() => { + let response = next.run(request).await; + Ok(response) + } + _ => Err(StatusCode::UNAUTHORIZED), + }, _ => Err(StatusCode::UNAUTHORIZED), } };