From 215afc15f04407cc32cf8a928f4691a9f07f2fde Mon Sep 17 00:00:00 2001
From: drbh
Date: Thu, 18 Jan 2024 11:03:05 -0500
Subject: [PATCH] fix: prefer env value from clap for better defaults

---
 router/src/main.rs   | 4 ++++
 router/src/server.rs | 3 ++-
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/router/src/main.rs b/router/src/main.rs
index f5d44305..bf987eb6 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -71,6 +71,8 @@ struct Args {
     ngrok_authtoken: Option<String>,
     #[clap(long, env)]
     ngrok_edge: Option<String>,
+    #[clap(long, env, default_value_t = false)]
+    chat_enabled_api: bool,
 }
 
 #[tokio::main]
@@ -102,6 +104,7 @@ async fn main() -> Result<(), RouterError> {
         ngrok,
         ngrok_authtoken,
         ngrok_edge,
+        chat_enabled_api,
     } = args;
 
     // Launch Tokio runtime
@@ -345,6 +348,7 @@ async fn main() -> Result<(), RouterError> {
         ngrok_authtoken,
         ngrok_edge,
         tokenizer_config,
+        chat_enabled_api,
     )
     .await?;
     Ok(())
diff --git a/router/src/server.rs b/router/src/server.rs
index 624e1c5a..69fb45df 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -708,6 +708,7 @@ pub async fn run(
     ngrok_authtoken: Option<String>,
     ngrok_edge: Option<String>,
     tokenizer_config: HubTokenizerConfig,
+    chat_enabled_api: bool,
 ) -> Result<(), axum::BoxError> {
     // OpenAPI documentation
     #[derive(OpenApi)]
@@ -871,7 +872,7 @@ pub async fn run(
         .route("/metrics", get(metrics));
 
     // Conditional AWS Sagemaker route
-    let aws_sagemaker_route = if std::env::var("OAI_ENABLED").map_or(false, |val| val == "true") {
+    let aws_sagemaker_route = if chat_enabled_api {
         Router::new().route("/invocations", post(chat_completions)) // Use 'chat_completions' for OAI_ENABLED
     } else {
         Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise