feat: add mocked http request tests

This commit is contained in:
drbh 2024-01-03 16:13:50 -05:00
parent 630800eed3
commit 2358a35485
5 changed files with 232 additions and 0 deletions

1
Cargo.lock generated
View File

@@ -2820,6 +2820,7 @@ dependencies = [
"tokenizers", "tokenizers",
"tokio", "tokio",
"tokio-stream", "tokio-stream",
"tower",
"tower-http", "tower-http",
"tracing", "tracing",
"tracing-opentelemetry", "tracing-opentelemetry",

View File

@@ -43,6 +43,7 @@ utoipa-swagger-ui = { version = "3.1.5", features = ["axum"] }
ngrok = { version = "0.13.1", features = ["axum"], optional = true } ngrok = { version = "0.13.1", features = ["axum"], optional = true }
hf-hub = "0.3.1" hf-hub = "0.3.1"
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] } init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
tower = "0.4.13"
[build-dependencies] [build-dependencies]
vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] } vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }

View File

@@ -17,6 +17,11 @@ impl ShardedClient {
Self { clients } Self { clients }
} }
/// Create a new ShardedClient with no shards. Used for testing, so the
/// router can be exercised without any live shard connections.
pub fn empty() -> Self {
    Self {
        clients: Vec::new(),
    }
}
/// Create a new ShardedClient from a master client. The master client will communicate with /// Create a new ShardedClient from a master client. The master client will communicate with
/// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method. /// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method.
async fn from_master_client(mut master_client: Client) -> Result<Self> { async fn from_master_client(mut master_client: Client) -> Result<Self> {

View File

@@ -138,6 +138,9 @@ pub(crate) struct GenerateParameters {
#[serde(default)] #[serde(default)]
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)] #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
pub top_n_tokens: Option<u32>, pub top_n_tokens: Option<u32>,
// useful when testing the router in isolation
skip_generation: Option<bool>,
} }
fn default_max_new_tokens() -> Option<u32> { fn default_max_new_tokens() -> Option<u32> {
@@ -162,6 +165,7 @@ fn default_parameters() -> GenerateParameters {
decoder_input_details: false, decoder_input_details: false,
seed: None, seed: None,
top_n_tokens: None, top_n_tokens: None,
skip_generation: None,
} }
} }

View File

@@ -160,6 +160,15 @@ async fn generate(
let details: bool = req.parameters.details || req.parameters.decoder_input_details; let details: bool = req.parameters.details || req.parameters.decoder_input_details;
// Early return if skip_generation is set
if req.parameters.skip_generation.unwrap_or(false) {
let response = GenerateResponse {
generated_text: req.inputs.clone(),
details: None,
};
return Ok((HeaderMap::new(), Json(response)));
}
// Inference // Inference
let (response, best_of_responses) = match req.parameters.best_of { let (response, best_of_responses) = match req.parameters.best_of {
Some(best_of) if best_of > 1 => { Some(best_of) if best_of > 1 => {
@@ -838,3 +847,215 @@ impl From<InferError> for Event {
.unwrap() .unwrap()
} }
} }
#[cfg(test)]
mod tests {
use super::*;
use axum::body::HttpBody;
use axum::{
body::Body,
http::{self, Request, StatusCode},
};
use serde_json::json;
use tower::util::ServiceExt;
/// Build the router for testing purposes.
///
/// Reconstructs the same route and extension-layer stack as the production
/// server, but backed by a shard-less `ShardedClient`, so no model shards
/// (and no network) are needed. The swagger-ui and prometheus layers are
/// intentionally omitted for tests.
async fn build_router() -> Router<(), axum::body::Body> {
// Set dummy values for testing: all limits are minimal (1) except the
// token budgets, which must be large enough to admit the test prompts.
let validation_workers = 1;
let tokenizer = None;
let waiting_served_ratio = 1.0;
let max_batch_prefill_tokens = 1;
let max_batch_total_tokens = 1;
let max_concurrent_requests = 1;
let max_waiting_tokens = 1;
let requires_padding = false;
let allow_origin = None;
let max_best_of = 1;
let max_stop_sequences = 1;
let max_input_length = 1024;
let max_total_tokens = 2048;
let max_top_n_tokens = 5;
// Create an empty client: zero shards, so any real generation attempt has
// no backend. Tests must either set `skip_generation` or fail validation
// before inference is reached.
let shardless_client = ShardedClient::empty();
// Create validation and inference
let validation = Validation::new(
validation_workers,
tokenizer,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_length,
max_total_tokens,
);
// Create shard info with placeholder hardware metadata (only surfaced
// through the `/info` endpoint and `Infer` configuration)
let shard_info = ShardInfo {
dtype: "demo".to_string(),
device_type: "none".to_string(),
window_size: Some(1),
speculate: 0,
requires_padding,
};
// Create model info
let model_info = HubModelInfo {
model_id: "test".to_string(),
sha: None,
pipeline_tag: None,
};
// Setup extension
// NOTE(review): generation_health starts `false`; presumably the health
// route reports unhealthy until a generation succeeds — confirm against
// the `Health` implementation.
let generation_health = Arc::new(AtomicBool::new(false));
let health_ext = Health::new(shardless_client.clone(), generation_health.clone());
// Build the Infer struct with the dummy values
let infer = Infer::new(
shardless_client,
validation,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_concurrent_requests,
shard_info.requires_padding,
shard_info.window_size,
shard_info.speculate,
generation_health,
);
// CORS layer; `allow_origin` is None above, so tests always fall back to
// allowing any origin.
let allow_origin = allow_origin.unwrap_or(AllowOrigin::any());
let cors_layer = CorsLayer::new()
.allow_methods([Method::GET, Method::POST])
.allow_headers([http::header::CONTENT_TYPE])
.allow_origin(allow_origin);
// Endpoint info served by `/info`
let info = Info {
model_id: model_info.model_id,
model_sha: model_info.sha,
model_dtype: shard_info.dtype,
model_device_type: shard_info.device_type,
model_pipeline_tag: model_info.pipeline_tag,
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_input_length,
max_total_tokens,
waiting_served_ratio,
max_batch_total_tokens,
max_waiting_tokens,
validation_workers,
version: env!("CARGO_PKG_VERSION"),
sha: option_env!("VERGEN_GIT_SHA"),
docker_label: option_env!("DOCKER_LABEL"),
};
let compat_return_full_text = true;
// Create router. Keep the `.route`/`.layer` ordering in sync with the
// production router: in axum, layers added later wrap everything added
// before them, so reordering changes middleware nesting.
let app: Router<(), Body> = Router::new()
// removed the swagger ui for testing
// Base routes
.route("/", post(compat_generate))
.route("/info", get(get_model_info))
.route("/generate", post(generate))
.route("/generate_stream", post(generate_stream))
// AWS Sagemaker route
.route("/invocations", post(compat_generate))
// Base Health route
.route("/health", get(health))
// Inference API health route
.route("/", get(health))
// AWS Sagemaker health route
.route("/ping", get(health))
// Prometheus metrics route
.route("/metrics", get(metrics))
.layer(Extension(info))
.layer(Extension(health_ext.clone()))
.layer(Extension(compat_return_full_text))
.layer(Extension(infer))
// removed the prometheus layer for testing
.layer(OtelAxumLayer::default())
.layer(cors_layer);
app
}
#[tokio::test]
async fn test_echo_inputs_when_skip_generation() {
let app = build_router().await;
let request_body = json!({
"inputs": "Hello world!",
"parameters": {
"stream": false,
// skip generation is needed for testing to avoid
// requests to non-existing client shards
"skip_generation": true
}
});
// `Router` implements `tower::Service<Request<Body>>` so we can
// call it like any tower service, no need to run an HTTP server.
let response = app
.oneshot(
Request::builder()
.uri("/generate")
.method(Method::POST)
.header(http::header::CONTENT_TYPE, "application/json")
.body(axum::body::Body::from(request_body.to_string()))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let body = response.into_body().collect().await.unwrap().to_bytes();
let utf8_body = std::str::from_utf8(&body[..]).unwrap();
let expected_response_body = json!({
"generated_text": "Hello world!"
});
assert_eq!(utf8_body, expected_response_body.to_string());
}
#[tokio::test]
async fn test_return_json_error_on_empty_inputs() {
let app = build_router().await;
let request_body = json!({
"inputs": "",
"parameters": {
"stream": false,
/* we do not need to skip_generation here because the validation will fail when trying to generate */
}
});
let response = app
.oneshot(
Request::builder()
.uri("/generate")
.method(Method::POST)
.header(http::header::CONTENT_TYPE, "application/json")
.body(axum::body::Body::from(request_body.to_string()))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY);
let body = response.into_body().collect().await.unwrap().to_bytes();
let utf8_body = std::str::from_utf8(&body[..]).unwrap();
let expected_response_body = json!({
"error":"Input validation error: `inputs` cannot be empty",
"error_type":"validation"
});
assert_eq!(utf8_body, expected_response_body.to_string());
}
}