Mirror of https://github.com/huggingface/text-generation-inference.git
feat: add mocked http request tests
This commit is contained in:
parent 630800eed3
commit 2358a35485
Cargo.lock (generated)
Cargo.lock:
@@ -2820,6 +2820,7 @@ dependencies = [
  "tokenizers",
  "tokio",
  "tokio-stream",
+ "tower",
  "tower-http",
  "tracing",
  "tracing-opentelemetry",
router/Cargo.toml:
@@ -43,6 +43,7 @@ utoipa-swagger-ui = { version = "3.1.5", features = ["axum"] }
 ngrok = { version = "0.13.1", features = ["axum"], optional = true }
 hf-hub = "0.3.1"
 init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
+tower = "0.4.13"
 
 [build-dependencies]
 vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }

The new `tower` dependency provides `ServiceExt::oneshot`, which the tests below use to call the router directly instead of running an HTTP server.
router/client/src/sharded_client.rs:
@@ -17,6 +17,11 @@ impl ShardedClient {
         Self { clients }
     }
 
+    /// Create a new ShardedClient with no shards. Used for testing
+    pub fn empty() -> Self {
+        Self { clients: vec![] }
+    }
+
     /// Create a new ShardedClient from a master client. The master client will communicate with
     /// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method.
     async fn from_master_client(mut master_client: Client) -> Result<Self> {
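The empty client is effectively a null object: constructing it opens no connection, so the router can be assembled in tests with no shard running, as long as no request is ever dispatched to inference. A self-contained analogue of the pattern (hypothetical stand-in types, not the ones above):

#[derive(Clone)]
struct Client; // stand-in for the gRPC shard client

#[derive(Clone)]
struct ShardedClient {
    clients: Vec<Client>,
}

impl ShardedClient {
    /// No shards: construction is free; only an actual dispatch would fail.
    fn empty() -> Self {
        Self { clients: vec![] }
    }
}

fn main() {
    let client = ShardedClient::empty();
    assert!(client.clients.is_empty());
}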
router/src/lib.rs:
@@ -138,6 +138,9 @@ pub(crate) struct GenerateParameters {
     #[serde(default)]
     #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
     pub top_n_tokens: Option<u32>,
+
+    // useful when testing the router in isolation
+    skip_generation: Option<bool>,
 }
 
 fn default_max_new_tokens() -> Option<u32> {
@@ -162,6 +165,7 @@ fn default_parameters() -> GenerateParameters {
         decoder_input_details: false,
         seed: None,
         top_n_tokens: None,
+        skip_generation: None,
     }
 }
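Existing payloads keep deserializing because serde's derive treats `Option<T>` fields as optional and fills missing ones with `None`. A minimal standalone sketch of that behavior (hypothetical `Params` type, not the struct above; assumes `serde` with the `derive` feature and `serde_json`):

use serde::Deserialize;

#[derive(Deserialize)]
struct Params {
    // mirrors the new optional field: absent from the JSON -> None
    skip_generation: Option<bool>,
}

fn main() {
    let with: Params = serde_json::from_str(r#"{"skip_generation": true}"#).unwrap();
    let without: Params = serde_json::from_str("{}").unwrap();
    assert_eq!(with.skip_generation, Some(true));
    assert_eq!(without.skip_generation, None);
}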
router/src/server.rs:
@@ -160,6 +160,15 @@ async fn generate(
 
     let details: bool = req.parameters.details || req.parameters.decoder_input_details;
 
+    // Early return if skip_generation is set
+    if req.parameters.skip_generation.unwrap_or(false) {
+        let response = GenerateResponse {
+            generated_text: req.inputs.clone(),
+            details: None,
+        };
+        return Ok((HeaderMap::new(), Json(response)));
+    }
+
     // Inference
     let (response, best_of_responses) = match req.parameters.best_of {
         Some(best_of) if best_of > 1 => {
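In other words, a request that sets `skip_generation` never reaches the shards and simply echoes its inputs back; the serialized response carries no `details` key, as the first test below confirms. The round trip, sketched with serde_json:

use serde_json::json;

fn main() {
    // skip_generation short-circuits inference entirely
    let request = json!({
        "inputs": "Hello world!",
        "parameters": { "skip_generation": true }
    });
    // the inputs come back verbatim as `generated_text`
    let expected = json!({ "generated_text": "Hello world!" });
    println!("{request} -> {expected}");
}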
router/src/server.rs:
@@ -838,3 +847,215 @@ impl From<InferError> for Event {
             .unwrap()
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use axum::body::HttpBody;
+    use axum::{
+        body::Body,
+        http::{self, Request, StatusCode},
+    };
+    use serde_json::json;
+    use tower::util::ServiceExt;
+
+    /// Build the router for testing purposes
+    async fn build_router() -> Router<(), axum::body::Body> {
+        // Set dummy values for testing
+        let validation_workers = 1;
+        let tokenizer = None;
+        let waiting_served_ratio = 1.0;
+        let max_batch_prefill_tokens = 1;
+        let max_batch_total_tokens = 1;
+        let max_concurrent_requests = 1;
+        let max_waiting_tokens = 1;
+        let requires_padding = false;
+        let allow_origin = None;
+        let max_best_of = 1;
+        let max_stop_sequences = 1;
+        let max_input_length = 1024;
+        let max_total_tokens = 2048;
+        let max_top_n_tokens = 5;
+
+        // Create an empty client
+        let shardless_client = ShardedClient::empty();
+
+        // Create validation and inference
+        let validation = Validation::new(
+            validation_workers,
+            tokenizer,
+            max_best_of,
+            max_stop_sequences,
+            max_top_n_tokens,
+            max_input_length,
+            max_total_tokens,
+        );
+
+        // Create shard info
+        let shard_info = ShardInfo {
+            dtype: "demo".to_string(),
+            device_type: "none".to_string(),
+            window_size: Some(1),
+            speculate: 0,
+            requires_padding,
+        };
+
+        // Create model info
+        let model_info = HubModelInfo {
+            model_id: "test".to_string(),
+            sha: None,
+            pipeline_tag: None,
+        };
+
+        // Setup extension
+        let generation_health = Arc::new(AtomicBool::new(false));
+        let health_ext = Health::new(shardless_client.clone(), generation_health.clone());
+
+        // Build the Infer struct with the dummy values
+        let infer = Infer::new(
+            shardless_client,
+            validation,
+            waiting_served_ratio,
+            max_batch_prefill_tokens,
+            max_batch_total_tokens,
+            max_waiting_tokens,
+            max_concurrent_requests,
+            shard_info.requires_padding,
+            shard_info.window_size,
+            shard_info.speculate,
+            generation_health,
+        );
+
+        // CORS layer
+        let allow_origin = allow_origin.unwrap_or(AllowOrigin::any());
+        let cors_layer = CorsLayer::new()
+            .allow_methods([Method::GET, Method::POST])
+            .allow_headers([http::header::CONTENT_TYPE])
+            .allow_origin(allow_origin);
+
+        // Endpoint info
+        let info = Info {
+            model_id: model_info.model_id,
+            model_sha: model_info.sha,
+            model_dtype: shard_info.dtype,
+            model_device_type: shard_info.device_type,
+            model_pipeline_tag: model_info.pipeline_tag,
+            max_concurrent_requests,
+            max_best_of,
+            max_stop_sequences,
+            max_input_length,
+            max_total_tokens,
+            waiting_served_ratio,
+            max_batch_total_tokens,
+            max_waiting_tokens,
+            validation_workers,
+            version: env!("CARGO_PKG_VERSION"),
+            sha: option_env!("VERGEN_GIT_SHA"),
+            docker_label: option_env!("DOCKER_LABEL"),
+        };
+
+        let compat_return_full_text = true;
+
+        // Create router
+        let app: Router<(), Body> = Router::new()
+            // removed the swagger ui for testing
+            // Base routes
+            .route("/", post(compat_generate))
+            .route("/info", get(get_model_info))
+            .route("/generate", post(generate))
+            .route("/generate_stream", post(generate_stream))
+            // AWS Sagemaker route
+            .route("/invocations", post(compat_generate))
+            // Base Health route
+            .route("/health", get(health))
+            // Inference API health route
+            .route("/", get(health))
+            // AWS Sagemaker health route
+            .route("/ping", get(health))
+            // Prometheus metrics route
+            .route("/metrics", get(metrics))
+            .layer(Extension(info))
+            .layer(Extension(health_ext.clone()))
+            .layer(Extension(compat_return_full_text))
+            .layer(Extension(infer))
+            // removed the prometheus layer for testing
+            .layer(OtelAxumLayer::default())
+            .layer(cors_layer);
+
+        app
+    }
+
+    #[tokio::test]
+    async fn test_echo_inputs_when_skip_generation() {
+        let app = build_router().await;
+
+        let request_body = json!({
+            "inputs": "Hello world!",
+            "parameters": {
+                "stream": false,
+                // skip_generation is needed for testing to avoid
+                // requests to non-existing client shards
+                "skip_generation": true
+            }
+        });
+
+        // `Router` implements `tower::Service<Request<Body>>` so we can
+        // call it like any tower service, no need to run an HTTP server.
+        let response = app
+            .oneshot(
+                Request::builder()
+                    .uri("/generate")
+                    .method(Method::POST)
+                    .header(http::header::CONTENT_TYPE, "application/json")
+                    .body(axum::body::Body::from(request_body.to_string()))
+                    .unwrap(),
+            )
+            .await
+            .unwrap();
+
+        assert_eq!(response.status(), StatusCode::OK);
+
+        let body = response.into_body().collect().await.unwrap().to_bytes();
+        let utf8_body = std::str::from_utf8(&body[..]).unwrap();
+
+        let expected_response_body = json!({
+            "generated_text": "Hello world!"
+        });
+        assert_eq!(utf8_body, expected_response_body.to_string());
+    }
+
+    #[tokio::test]
+    async fn test_return_json_error_on_empty_inputs() {
+        let app = build_router().await;
+
+        let request_body = json!({
+            "inputs": "",
+            "parameters": {
+                "stream": false,
+                // no need for skip_generation here: validation fails
+                // on the empty inputs before generation is attempted
+            }
+        });
+
+        let response = app
+            .oneshot(
+                Request::builder()
+                    .uri("/generate")
+                    .method(Method::POST)
+                    .header(http::header::CONTENT_TYPE, "application/json")
+                    .body(axum::body::Body::from(request_body.to_string()))
+                    .unwrap(),
+            )
+            .await
+            .unwrap();
+
+        assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY);
+
+        let body = response.into_body().collect().await.unwrap().to_bytes();
+        let utf8_body = std::str::from_utf8(&body[..]).unwrap();
+
+        let expected_response_body = json!({
+            "error": "Input validation error: `inputs` cannot be empty",
+            "error_type": "validation"
+        });
+        assert_eq!(utf8_body, expected_response_body.to_string());
+    }
+}
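Both tests lean on the same `oneshot` pattern. Reduced to a self-contained sketch (assuming the axum 0.6-era API this diff targets, where `Router` implements `tower::Service<Request<Body>>`; the `/health` route and closure handler here are illustrative only, not the router above):

use axum::{
    body::Body,
    http::{Request, StatusCode},
    routing::get,
    Router,
};
use tower::util::ServiceExt; // provides `oneshot`

#[tokio::test]
async fn oneshot_demo() {
    let app: Router = Router::new().route("/health", get(|| async { "ok" }));

    // Drive the service exactly once; no TCP listener is ever bound.
    let response = app
        .oneshot(Request::builder().uri("/health").body(Body::empty()).unwrap())
        .await
        .unwrap();

    assert_eq!(response.status(), StatusCode::OK);
}

Since `oneshot` consumes the service, each test builds a fresh router via `build_router()` rather than sharing one.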