diff --git a/router/src/main.rs b/router/src/main.rs index 24c4c14d..bebf0e53 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -74,6 +74,8 @@ struct Args { ngrok_authtoken: Option, #[clap(long, env)] ngrok_edge: Option, + #[clap(long, env, default_value_t = false)] + chat_enabled_api: bool, } #[tokio::main] @@ -105,6 +107,7 @@ async fn main() -> Result<(), RouterError> { ngrok, ngrok_authtoken, ngrok_edge, + chat_enabled_api, } = args; // Launch Tokio runtime @@ -356,6 +359,7 @@ async fn main() -> Result<(), RouterError> { ngrok_authtoken, ngrok_edge, tokenizer_config, + chat_enabled_api, ) .await?; Ok(()) diff --git a/router/src/server.rs b/router/src/server.rs index abddf81f..cd8790b0 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -710,6 +710,7 @@ pub async fn run( ngrok_authtoken: Option, ngrok_edge: Option, tokenizer_config: HubTokenizerConfig, + chat_enabled_api: bool, ) -> Result<(), axum::BoxError> { // OpenAPI documentation #[derive(OpenApi)] @@ -860,25 +861,32 @@ pub async fn run( docker_label: option_env!("DOCKER_LABEL"), }; - // Create router - let app = Router::new() - .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi())) - // Base routes + // Configure Swagger UI + let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()); + + // Define base and health routes + let base_routes = Router::new() .route("/", post(compat_generate)) .route("/info", get(get_model_info)) .route("/generate", post(generate)) .route("/generate_stream", post(generate_stream)) .route("/v1/chat/completions", post(chat_completions)) - // AWS Sagemaker route - .route("/invocations", post(compat_generate)) - // Base Health route .route("/health", get(health)) - // Inference API health route - .route("/", get(health)) - // AWS Sagemaker health route .route("/ping", get(health)) - // Prometheus metrics route - .route("/metrics", get(metrics)) + .route("/metrics", get(metrics)); + + // Conditional AWS Sagemaker route + let aws_sagemaker_route = if chat_enabled_api { + Router::new().route("/invocations", post(chat_completions)) // Use 'chat_completions' for OAI_ENABLED + } else { + Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise + }; + + // Combine routes and layers + let app = Router::new() + .merge(swagger_ui) + .merge(base_routes) + .merge(aws_sagemaker_route) .layer(Extension(info)) .layer(Extension(health_ext.clone())) .layer(Extension(compat_return_full_text))