diff --git a/Cargo.lock b/Cargo.lock index 2284ef84..ea6149d3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2275,6 +2275,7 @@ dependencies = [ "tokenizers", "tokio", "tokio-stream", + "tower-http", "tracing", "tracing-opentelemetry", "tracing-subscriber", diff --git a/launcher/src/main.rs b/launcher/src/main.rs index f7ff1cca..ac118566 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -53,6 +53,8 @@ struct Args { json_output: bool, #[clap(long, env)] otlp_endpoint: Option, + #[clap(long, env)] + cors_allow_origin: Vec, } fn main() -> ExitCode { @@ -85,6 +87,7 @@ fn main() -> ExitCode { disable_custom_kernels, json_output, otlp_endpoint, + cors_allow_origin, } = args; // Signal handler @@ -320,6 +323,12 @@ fn main() -> ExitCode { argv.push(otlp_endpoint); } + // CORS origins + for origin in cors_allow_origin.into_iter() { + argv.push("--cors-allow-origin".to_string()); + argv.push(origin); + } + let mut webserver = match Popen::create( &argv, PopenConfig { diff --git a/router/Cargo.toml b/router/Cargo.toml index 156adad7..97a88d55 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -32,6 +32,7 @@ thiserror = "1.0.38" tokenizers = "0.13.2" tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } tokio-stream = "0.1.11" +tower-http = { version = "0.3.5", features = ["cors"] } tracing = "0.1.37" tracing-opentelemetry = "0.18.0" tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] } diff --git a/router/src/main.rs b/router/src/main.rs index 5ababa4b..f1cf09a0 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -1,4 +1,5 @@ /// Text Generation Inference webserver entrypoint +use axum::http::HeaderValue; use clap::Parser; use opentelemetry::sdk::propagation::TraceContextPropagator; use opentelemetry::sdk::trace; @@ -10,6 +11,7 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use text_generation_client::ShardedClient; use text_generation_router::server; use tokenizers::Tokenizer; +use tower_http::cors::AllowOrigin; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; use tracing_subscriber::{EnvFilter, Layer}; @@ -42,6 +44,8 @@ struct Args { json_output: bool, #[clap(long, env)] otlp_endpoint: Option, + #[clap(long, env)] + cors_allow_origin: Option>, } fn main() -> Result<(), std::io::Error> { @@ -61,12 +65,24 @@ fn main() -> Result<(), std::io::Error> { validation_workers, json_output, otlp_endpoint, + cors_allow_origin, } = args; if validation_workers == 0 { panic!("validation_workers must be > 0"); } + // CORS allowed origins + // map to go inside the option and then map to parse from String to HeaderValue + // Finally, convert to AllowOrigin + let cors_allow_origin: Option = cors_allow_origin.map(|cors_allow_origin| { + AllowOrigin::list( + cors_allow_origin + .iter() + .map(|origin| origin.parse::().unwrap()), + ) + }); + // Download and instantiate tokenizer // This will only be used to validate payloads // @@ -107,6 +123,7 @@ fn main() -> Result<(), std::io::Error> { tokenizer, validation_workers, addr, + cors_allow_origin, ) .await; Ok(()) diff --git a/router/src/server.rs b/router/src/server.rs index 48affa46..6acbbffa 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -5,11 +5,11 @@ use crate::{ Infer, StreamDetails, StreamResponse, Token, Validation, }; use axum::extract::Extension; -use axum::http::{HeaderMap, StatusCode}; +use axum::http::{HeaderMap, Method, StatusCode}; use axum::response::sse::{Event, KeepAlive, Sse}; use axum::response::IntoResponse; use axum::routing::{get, post}; -use axum::{Json, Router}; +use axum::{http, Json, Router}; use axum_tracing_opentelemetry::opentelemetry_tracing_layer; use futures::Stream; use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle}; @@ -20,6 +20,7 @@ use tokenizers::Tokenizer; use tokio::signal; use tokio::time::Instant; use tokio_stream::StreamExt; +use tower_http::cors::{AllowOrigin, CorsLayer}; use tracing::{info_span, instrument, Instrument}; use utoipa::OpenApi; use utoipa_swagger_ui::SwaggerUi; @@ -334,6 +335,7 @@ pub async fn run( tokenizer: Tokenizer, validation_workers: usize, addr: SocketAddr, + allow_origin: Option, ) { // OpenAPI documentation #[derive(OpenApi)] @@ -391,6 +393,13 @@ pub async fn run( .install_recorder() .expect("failed to install metrics recorder"); + // CORS layer + let allow_origin = allow_origin.unwrap_or(AllowOrigin::any()); + let cors_layer = CorsLayer::new() + .allow_methods([Method::GET, Method::POST]) + .allow_headers([http::header::CONTENT_TYPE]) + .allow_origin(allow_origin); + // Create router let app = Router::new() .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi())) @@ -402,7 +411,8 @@ pub async fn run( .layer(Extension(infer)) .route("/metrics", get(metrics)) .layer(Extension(prom_handle)) - .layer(opentelemetry_tracing_layer()); + .layer(opentelemetry_tracing_layer()) + .layer(cors_layer); // Run server axum::Server::bind(&addr)