diff --git a/Cargo.lock b/Cargo.lock index 27c345b3..234e2bfa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2820,6 +2820,7 @@ dependencies = [ "tokenizers", "tokio", "tokio-stream", + "tower", "tower-http", "tracing", "tracing-opentelemetry", diff --git a/router/Cargo.toml b/router/Cargo.toml index 55af635a..e4c58e0c 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -43,6 +43,7 @@ utoipa-swagger-ui = { version = "3.1.5", features = ["axum"] } ngrok = { version = "0.13.1", features = ["axum"], optional = true } hf-hub = "0.3.1" init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] } +tower = "0.4.13" [build-dependencies] vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] } diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs index 6c5da3c7..cadfb032 100644 --- a/router/client/src/sharded_client.rs +++ b/router/client/src/sharded_client.rs @@ -17,6 +17,11 @@ impl ShardedClient { Self { clients } } + /// Create a new ShardedClient with no shards. Used for testing + pub fn empty() -> Self { + Self { clients: vec![] } + } + /// Create a new ShardedClient from a master client. The master client will communicate with /// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method. 
async fn from_master_client(mut master_client: Client) -> Result<Self> { diff --git a/router/src/lib.rs b/router/src/lib.rs index 898fcd04..cb75265f 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -138,6 +138,9 @@ pub(crate) struct GenerateParameters { #[serde(default)] #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)] pub top_n_tokens: Option<u32>, + + // useful when testing the router in isolation + skip_generation: Option<bool>, } fn default_max_new_tokens() -> Option<u32> { @@ -162,6 +165,7 @@ fn default_parameters() -> GenerateParameters { decoder_input_details: false, seed: None, top_n_tokens: None, + skip_generation: None, } } diff --git a/router/src/server.rs b/router/src/server.rs index fe1b8309..c7026987 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -160,6 +160,15 @@ async fn generate( let details: bool = req.parameters.details || req.parameters.decoder_input_details; + // Early return if skip_generation is set + if req.parameters.skip_generation.unwrap_or(false) { + let response = GenerateResponse { + generated_text: req.inputs.clone(), + details: None, + }; + return Ok((HeaderMap::new(), Json(response))); + } + // Inference let (response, best_of_responses) = match req.parameters.best_of { Some(best_of) if best_of > 1 => { @@ -838,3 +847,215 @@ impl From<InferError> for Event { .unwrap() } } + +#[cfg(test)] +mod tests { + use super::*; + use axum::body::HttpBody; + use axum::{ + body::Body, + http::{self, Request, StatusCode}, + }; + use serde_json::json; + use tower::util::ServiceExt; + + /// Build the router for testing purposes + async fn build_router() -> Router<(), axum::body::Body> { + // Set dummy values for testing + let validation_workers = 1; + let tokenizer = None; + let waiting_served_ratio = 1.0; + let max_batch_prefill_tokens = 1; + let max_batch_total_tokens = 1; + let max_concurrent_requests = 1; + let max_waiting_tokens = 1; + let requires_padding = false; + let allow_origin = None; + let max_best_of = 1; + let 
max_stop_sequences = 1; + let max_input_length = 1024; + let max_total_tokens = 2048; + let max_top_n_tokens = 5; + + // Create an empty client + let shardless_client = ShardedClient::empty(); + + // Create validation and inference + let validation = Validation::new( + validation_workers, + tokenizer, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_length, + max_total_tokens, + ); + + // Create shard info + let shard_info = ShardInfo { + dtype: "demo".to_string(), + device_type: "none".to_string(), + window_size: Some(1), + speculate: 0, + requires_padding, + }; + + // Create model info + let model_info = HubModelInfo { + model_id: "test".to_string(), + sha: None, + pipeline_tag: None, + }; + + // Setup extension + let generation_health = Arc::new(AtomicBool::new(false)); + let health_ext = Health::new(shardless_client.clone(), generation_health.clone()); + + // Build the Infer struct with the dummy values + let infer = Infer::new( + shardless_client, + validation, + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_concurrent_requests, + shard_info.requires_padding, + shard_info.window_size, + shard_info.speculate, + generation_health, + ); + + // CORS layer + let allow_origin = allow_origin.unwrap_or(AllowOrigin::any()); + let cors_layer = CorsLayer::new() + .allow_methods([Method::GET, Method::POST]) + .allow_headers([http::header::CONTENT_TYPE]) + .allow_origin(allow_origin); + + // Endpoint info + let info = Info { + model_id: model_info.model_id, + model_sha: model_info.sha, + model_dtype: shard_info.dtype, + model_device_type: shard_info.device_type, + model_pipeline_tag: model_info.pipeline_tag, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_input_length, + max_total_tokens, + waiting_served_ratio, + max_batch_total_tokens, + max_waiting_tokens, + validation_workers, + version: env!("CARGO_PKG_VERSION"), + sha: option_env!("VERGEN_GIT_SHA"), + docker_label: 
option_env!("DOCKER_LABEL"), + }; + + let compat_return_full_text = true; + + // Create router + let app: Router<(), Body> = Router::new() + // removed the swagger ui for testing + // Base routes + .route("/", post(compat_generate)) + .route("/info", get(get_model_info)) + .route("/generate", post(generate)) + .route("/generate_stream", post(generate_stream)) + // AWS Sagemaker route + .route("/invocations", post(compat_generate)) + // Base Health route + .route("/health", get(health)) + // Inference API health route + .route("/", get(health)) + // AWS Sagemaker health route + .route("/ping", get(health)) + // Prometheus metrics route + .route("/metrics", get(metrics)) + .layer(Extension(info)) + .layer(Extension(health_ext.clone())) + .layer(Extension(compat_return_full_text)) + .layer(Extension(infer)) + // removed the prometheus layer for testing + .layer(OtelAxumLayer::default()) + .layer(cors_layer); + + app + } + + #[tokio::test] + async fn test_echo_inputs_when_skip_generation() { + let app = build_router().await; + + let request_body = json!({ + "inputs": "Hello world!", + "parameters": { + "stream": false, + // skip generation is needed for testing to avoid + // requests to non-existing client shards + "skip_generation": true + } + }); + // `Router` implements `tower::Service<Request<Body>>` so we can + // call it like any tower service, no need to run an HTTP server. + let response = app + .oneshot( + Request::builder() + .uri("/generate") + .method(Method::POST) + .header(http::header::CONTENT_TYPE, "application/json") + .body(axum::body::Body::from(request_body.to_string())) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + + let body = response.into_body().collect().await.unwrap().to_bytes(); + let utf8_body = std::str::from_utf8(&body[..]).unwrap(); + + let expected_response_body = json!({ + "generated_text": "Hello world!" 
+ }); + assert_eq!(utf8_body, expected_response_body.to_string()); + } + + #[tokio::test] + async fn test_return_json_error_on_empty_inputs() { + let app = build_router().await; + + let request_body = json!({ + "inputs": "", + "parameters": { + "stream": false, + /* we do not need to skip_generation here because the validation will fail when trying to generate */ + } + }); + + let response = app + .oneshot( + Request::builder() + .uri("/generate") + .method(Method::POST) + .header(http::header::CONTENT_TYPE, "application/json") + .body(axum::body::Body::from(request_body.to_string())) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); + + let body = response.into_body().collect().await.unwrap().to_bytes(); + let utf8_body = std::str::from_utf8(&body[..]).unwrap(); + + let expected_response_body = json!({ + "error":"Input validation error: `inputs` cannot be empty", + "error_type":"validation" + }); + assert_eq!(utf8_body, expected_response_body.to_string()); + } +}