From 90541fba071c83f17e021f92f48de22a059e9619 Mon Sep 17 00:00:00 2001 From: drbh Date: Thu, 18 Jan 2024 09:38:40 -0500 Subject: [PATCH] feat: conditionally toggle chat --- router/src/server.rs | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/router/src/server.rs b/router/src/server.rs index fe1827c4..624e1c5a 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -856,25 +856,32 @@ pub async fn run( docker_label: option_env!("DOCKER_LABEL"), }; - // Create router - let app = Router::new() - .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi())) - // Base routes + // Configure Swagger UI + let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()); + + // Define base and health routes + let base_routes = Router::new() .route("/", post(compat_generate)) .route("/info", get(get_model_info)) .route("/generate", post(generate)) .route("/generate_stream", post(generate_stream)) .route("/v1/chat/completions", post(chat_completions)) - // AWS Sagemaker route - .route("/invocations", post(compat_generate)) - // Base Health route .route("/health", get(health)) - // Inference API health route - .route("/", get(health)) - // AWS Sagemaker health route .route("/ping", get(health)) - // Prometheus metrics route - .route("/metrics", get(metrics)) + .route("/metrics", get(metrics)); + + // Conditional AWS Sagemaker route + let aws_sagemaker_route = if std::env::var("OAI_ENABLED").map_or(false, |val| val == "true") { + Router::new().route("/invocations", post(chat_completions)) // Use 'chat_completions' for OAI_ENABLED + } else { + Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise + }; + + // Combine routes and layers + let app = Router::new() + .merge(swagger_ui) + .merge(base_routes) + .merge(aws_sagemaker_route) .layer(Extension(info)) .layer(Extension(health_ext.clone())) .layer(Extension(compat_return_full_text))