diff --git a/router/src/main.rs b/router/src/main.rs index f5d44305..bf987eb6 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -71,6 +71,8 @@ struct Args { ngrok_authtoken: Option, #[clap(long, env)] ngrok_edge: Option, + #[clap(long, env, default_value_t = false)] + chat_enabled_api: bool, } #[tokio::main] @@ -102,6 +104,7 @@ async fn main() -> Result<(), RouterError> { ngrok, ngrok_authtoken, ngrok_edge, + chat_enabled_api, } = args; // Launch Tokio runtime @@ -345,6 +348,7 @@ async fn main() -> Result<(), RouterError> { ngrok_authtoken, ngrok_edge, tokenizer_config, + chat_enabled_api, ) .await?; Ok(()) diff --git a/router/src/server.rs b/router/src/server.rs index 624e1c5a..69fb45df 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -708,6 +708,7 @@ pub async fn run( ngrok_authtoken: Option, ngrok_edge: Option, tokenizer_config: HubTokenizerConfig, + chat_enabled_api: bool, ) -> Result<(), axum::BoxError> { // OpenAPI documentation #[derive(OpenApi)] @@ -871,7 +872,7 @@ pub async fn run( .route("/metrics", get(metrics)); // Conditional AWS Sagemaker route - let aws_sagemaker_route = if std::env::var("OAI_ENABLED").map_or(false, |val| val == "true") { + let aws_sagemaker_route = if chat_enabled_api { Router::new().route("/invocations", post(chat_completions)) // Use 'chat_completions' for OAI_ENABLED } else { Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise