From 5836a1cc69e85862688b24924fd5244f36c84ce0 Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 22 Jan 2024 10:29:01 -0500 Subject: [PATCH] feat: conditionally toggle chat on invocations route (#1454) This PR adds support for reading the `CHAT_ENABLED_API` env var (or the `--chat-enabled-api` flag), which changes the handler used when the `/invocations` route is called. If `CHAT_ENABLED_API=true`, the `chat_completions` method is used; otherwise it defaults to `compat_generate`. Example of running the router: ```bash CHAT_ENABLED_API=true \ cargo run -- \ --tokenizer-name mistralai/Mistral-7B-Instruct-v0.2 ``` Example request: ```bash curl localhost:3000/invocations \ -X POST \ -d '{ "model": "tgi", "messages": [ { "role": "user", "content": "What is the IP address of the Google DNS servers?" } ], "stream": false, "max_tokens": 20, "logprobs": true, "seed": 0 }' \ -H 'Content-Type: application/json' | jq ``` **Please let me know if any naming changes are needed or if any other routes need similar functionality.** --- router/src/main.rs | 4 ++++ router/src/server.rs | 32 ++++++++++++++++++++------------ 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/router/src/main.rs b/router/src/main.rs index 24c4c14d..bebf0e53 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -74,6 +74,8 @@ struct Args { ngrok_authtoken: Option, #[clap(long, env)] ngrok_edge: Option, + #[clap(long, env, default_value_t = false)] + chat_enabled_api: bool, } #[tokio::main] @@ -105,6 +107,7 @@ async fn main() -> Result<(), RouterError> { ngrok, ngrok_authtoken, ngrok_edge, + chat_enabled_api, } = args; // Launch Tokio runtime @@ -356,6 +359,7 @@ async fn main() -> Result<(), RouterError> { ngrok_authtoken, ngrok_edge, tokenizer_config, + chat_enabled_api, ) .await?; Ok(()) diff --git a/router/src/server.rs b/router/src/server.rs index abddf81f..cd8790b0 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -710,6 +710,7 @@ pub async fn run( ngrok_authtoken: Option, ngrok_edge: Option, tokenizer_config: HubTokenizerConfig, 
+ chat_enabled_api: bool, ) -> Result<(), axum::BoxError> { // OpenAPI documentation #[derive(OpenApi)] @@ -860,25 +861,32 @@ pub async fn run( docker_label: option_env!("DOCKER_LABEL"), }; - // Create router - let app = Router::new() - .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi())) - // Base routes + // Configure Swagger UI + let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()); + + // Define base and health routes + let base_routes = Router::new() .route("/", post(compat_generate)) .route("/info", get(get_model_info)) .route("/generate", post(generate)) .route("/generate_stream", post(generate_stream)) .route("/v1/chat/completions", post(chat_completions)) - // AWS Sagemaker route - .route("/invocations", post(compat_generate)) - // Base Health route .route("/health", get(health)) - // Inference API health route - .route("/", get(health)) - // AWS Sagemaker health route .route("/ping", get(health)) - // Prometheus metrics route - .route("/metrics", get(metrics)) + .route("/metrics", get(metrics)); + + // Conditional AWS Sagemaker route + let aws_sagemaker_route = if chat_enabled_api { + Router::new().route("/invocations", post(chat_completions)) // Use 'chat_completions' for OAI_ENABLED + } else { + Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise + }; + + // Combine routes and layers + let app = Router::new() + .merge(swagger_ui) + .merge(base_routes) + .merge(aws_sagemaker_route) .layer(Extension(info)) .layer(Extension(health_ext.clone())) .layer(Extension(compat_return_full_text))