mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
feat: add mocked http request tests
This commit is contained in:
parent
630800eed3
commit
2358a35485
1
Cargo.lock
generated
1
Cargo.lock
generated
@ -2820,6 +2820,7 @@ dependencies = [
|
|||||||
"tokenizers",
|
"tokenizers",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-stream",
|
"tokio-stream",
|
||||||
|
"tower",
|
||||||
"tower-http",
|
"tower-http",
|
||||||
"tracing",
|
"tracing",
|
||||||
"tracing-opentelemetry",
|
"tracing-opentelemetry",
|
||||||
|
@ -43,6 +43,7 @@ utoipa-swagger-ui = { version = "3.1.5", features = ["axum"] }
|
|||||||
ngrok = { version = "0.13.1", features = ["axum"], optional = true }
|
ngrok = { version = "0.13.1", features = ["axum"], optional = true }
|
||||||
hf-hub = "0.3.1"
|
hf-hub = "0.3.1"
|
||||||
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
|
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
|
||||||
|
tower = "0.4.13"
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
|
vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
|
||||||
|
@ -17,6 +17,11 @@ impl ShardedClient {
|
|||||||
Self { clients }
|
Self { clients }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create a new ShardedClient with no shards. Used for testing
|
||||||
|
pub fn empty() -> Self {
|
||||||
|
Self { clients: vec![] }
|
||||||
|
}
|
||||||
|
|
||||||
/// Create a new ShardedClient from a master client. The master client will communicate with
|
/// Create a new ShardedClient from a master client. The master client will communicate with
|
||||||
/// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method.
|
/// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method.
|
||||||
async fn from_master_client(mut master_client: Client) -> Result<Self> {
|
async fn from_master_client(mut master_client: Client) -> Result<Self> {
|
||||||
|
@ -138,6 +138,9 @@ pub(crate) struct GenerateParameters {
|
|||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
|
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
|
||||||
pub top_n_tokens: Option<u32>,
|
pub top_n_tokens: Option<u32>,
|
||||||
|
|
||||||
|
// useful when testing the router in isolation
|
||||||
|
skip_generation: Option<bool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_max_new_tokens() -> Option<u32> {
|
fn default_max_new_tokens() -> Option<u32> {
|
||||||
@ -162,6 +165,7 @@ fn default_parameters() -> GenerateParameters {
|
|||||||
decoder_input_details: false,
|
decoder_input_details: false,
|
||||||
seed: None,
|
seed: None,
|
||||||
top_n_tokens: None,
|
top_n_tokens: None,
|
||||||
|
skip_generation: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -160,6 +160,15 @@ async fn generate(
|
|||||||
|
|
||||||
let details: bool = req.parameters.details || req.parameters.decoder_input_details;
|
let details: bool = req.parameters.details || req.parameters.decoder_input_details;
|
||||||
|
|
||||||
|
// Early return if skip_generation is set
|
||||||
|
if req.parameters.skip_generation.unwrap_or(false) {
|
||||||
|
let response = GenerateResponse {
|
||||||
|
generated_text: req.inputs.clone(),
|
||||||
|
details: None,
|
||||||
|
};
|
||||||
|
return Ok((HeaderMap::new(), Json(response)));
|
||||||
|
}
|
||||||
|
|
||||||
// Inference
|
// Inference
|
||||||
let (response, best_of_responses) = match req.parameters.best_of {
|
let (response, best_of_responses) = match req.parameters.best_of {
|
||||||
Some(best_of) if best_of > 1 => {
|
Some(best_of) if best_of > 1 => {
|
||||||
@ -838,3 +847,215 @@ impl From<InferError> for Event {
|
|||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use axum::body::HttpBody;
|
||||||
|
use axum::{
|
||||||
|
body::Body,
|
||||||
|
http::{self, Request, StatusCode},
|
||||||
|
};
|
||||||
|
use serde_json::json;
|
||||||
|
use tower::util::ServiceExt;
|
||||||
|
|
||||||
|
/// Build the router for testing purposes
|
||||||
|
async fn build_router() -> Router<(), axum::body::Body> {
|
||||||
|
// Set dummy values for testing
|
||||||
|
let validation_workers = 1;
|
||||||
|
let tokenizer = None;
|
||||||
|
let waiting_served_ratio = 1.0;
|
||||||
|
let max_batch_prefill_tokens = 1;
|
||||||
|
let max_batch_total_tokens = 1;
|
||||||
|
let max_concurrent_requests = 1;
|
||||||
|
let max_waiting_tokens = 1;
|
||||||
|
let requires_padding = false;
|
||||||
|
let allow_origin = None;
|
||||||
|
let max_best_of = 1;
|
||||||
|
let max_stop_sequences = 1;
|
||||||
|
let max_input_length = 1024;
|
||||||
|
let max_total_tokens = 2048;
|
||||||
|
let max_top_n_tokens = 5;
|
||||||
|
|
||||||
|
// Create an empty client
|
||||||
|
let shardless_client = ShardedClient::empty();
|
||||||
|
|
||||||
|
// Create validation and inference
|
||||||
|
let validation = Validation::new(
|
||||||
|
validation_workers,
|
||||||
|
tokenizer,
|
||||||
|
max_best_of,
|
||||||
|
max_stop_sequences,
|
||||||
|
max_top_n_tokens,
|
||||||
|
max_input_length,
|
||||||
|
max_total_tokens,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Create shard info
|
||||||
|
let shard_info = ShardInfo {
|
||||||
|
dtype: "demo".to_string(),
|
||||||
|
device_type: "none".to_string(),
|
||||||
|
window_size: Some(1),
|
||||||
|
speculate: 0,
|
||||||
|
requires_padding,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Create model info
|
||||||
|
let model_info = HubModelInfo {
|
||||||
|
model_id: "test".to_string(),
|
||||||
|
sha: None,
|
||||||
|
pipeline_tag: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Setup extension
|
||||||
|
let generation_health = Arc::new(AtomicBool::new(false));
|
||||||
|
let health_ext = Health::new(shardless_client.clone(), generation_health.clone());
|
||||||
|
|
||||||
|
// Build the Infer struct with the dummy values
|
||||||
|
let infer = Infer::new(
|
||||||
|
shardless_client,
|
||||||
|
validation,
|
||||||
|
waiting_served_ratio,
|
||||||
|
max_batch_prefill_tokens,
|
||||||
|
max_batch_total_tokens,
|
||||||
|
max_waiting_tokens,
|
||||||
|
max_concurrent_requests,
|
||||||
|
shard_info.requires_padding,
|
||||||
|
shard_info.window_size,
|
||||||
|
shard_info.speculate,
|
||||||
|
generation_health,
|
||||||
|
);
|
||||||
|
|
||||||
|
// CORS layer
|
||||||
|
let allow_origin = allow_origin.unwrap_or(AllowOrigin::any());
|
||||||
|
let cors_layer = CorsLayer::new()
|
||||||
|
.allow_methods([Method::GET, Method::POST])
|
||||||
|
.allow_headers([http::header::CONTENT_TYPE])
|
||||||
|
.allow_origin(allow_origin);
|
||||||
|
|
||||||
|
// Endpoint info
|
||||||
|
let info = Info {
|
||||||
|
model_id: model_info.model_id,
|
||||||
|
model_sha: model_info.sha,
|
||||||
|
model_dtype: shard_info.dtype,
|
||||||
|
model_device_type: shard_info.device_type,
|
||||||
|
model_pipeline_tag: model_info.pipeline_tag,
|
||||||
|
max_concurrent_requests,
|
||||||
|
max_best_of,
|
||||||
|
max_stop_sequences,
|
||||||
|
max_input_length,
|
||||||
|
max_total_tokens,
|
||||||
|
waiting_served_ratio,
|
||||||
|
max_batch_total_tokens,
|
||||||
|
max_waiting_tokens,
|
||||||
|
validation_workers,
|
||||||
|
version: env!("CARGO_PKG_VERSION"),
|
||||||
|
sha: option_env!("VERGEN_GIT_SHA"),
|
||||||
|
docker_label: option_env!("DOCKER_LABEL"),
|
||||||
|
};
|
||||||
|
|
||||||
|
let compat_return_full_text = true;
|
||||||
|
|
||||||
|
// Create router
|
||||||
|
let app: Router<(), Body> = Router::new()
|
||||||
|
// removed the swagger ui for testing
|
||||||
|
// Base routes
|
||||||
|
.route("/", post(compat_generate))
|
||||||
|
.route("/info", get(get_model_info))
|
||||||
|
.route("/generate", post(generate))
|
||||||
|
.route("/generate_stream", post(generate_stream))
|
||||||
|
// AWS Sagemaker route
|
||||||
|
.route("/invocations", post(compat_generate))
|
||||||
|
// Base Health route
|
||||||
|
.route("/health", get(health))
|
||||||
|
// Inference API health route
|
||||||
|
.route("/", get(health))
|
||||||
|
// AWS Sagemaker health route
|
||||||
|
.route("/ping", get(health))
|
||||||
|
// Prometheus metrics route
|
||||||
|
.route("/metrics", get(metrics))
|
||||||
|
.layer(Extension(info))
|
||||||
|
.layer(Extension(health_ext.clone()))
|
||||||
|
.layer(Extension(compat_return_full_text))
|
||||||
|
.layer(Extension(infer))
|
||||||
|
// removed the prometheus layer for testing
|
||||||
|
.layer(OtelAxumLayer::default())
|
||||||
|
.layer(cors_layer);
|
||||||
|
|
||||||
|
app
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_echo_inputs_when_skip_generation() {
|
||||||
|
let app = build_router().await;
|
||||||
|
|
||||||
|
let request_body = json!({
|
||||||
|
"inputs": "Hello world!",
|
||||||
|
"parameters": {
|
||||||
|
"stream": false,
|
||||||
|
// skip generation is needed for testing to avoid
|
||||||
|
// requests to non-existing client shards
|
||||||
|
"skip_generation": true
|
||||||
|
}
|
||||||
|
});
|
||||||
|
// `Router` implements `tower::Service<Request<Body>>` so we can
|
||||||
|
// call it like any tower service, no need to run an HTTP server.
|
||||||
|
let response = app
|
||||||
|
.oneshot(
|
||||||
|
Request::builder()
|
||||||
|
.uri("/generate")
|
||||||
|
.method(Method::POST)
|
||||||
|
.header(http::header::CONTENT_TYPE, "application/json")
|
||||||
|
.body(axum::body::Body::from(request_body.to_string()))
|
||||||
|
.unwrap(),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(response.status(), StatusCode::OK);
|
||||||
|
|
||||||
|
let body = response.into_body().collect().await.unwrap().to_bytes();
|
||||||
|
let utf8_body = std::str::from_utf8(&body[..]).unwrap();
|
||||||
|
|
||||||
|
let expected_response_body = json!({
|
||||||
|
"generated_text": "Hello world!"
|
||||||
|
});
|
||||||
|
assert_eq!(utf8_body, expected_response_body.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_return_json_error_on_empty_inputs() {
|
||||||
|
let app = build_router().await;
|
||||||
|
|
||||||
|
let request_body = json!({
|
||||||
|
"inputs": "",
|
||||||
|
"parameters": {
|
||||||
|
"stream": false,
|
||||||
|
/* we do not need to skip_generation here because the validation will fail when trying to generate */
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let response = app
|
||||||
|
.oneshot(
|
||||||
|
Request::builder()
|
||||||
|
.uri("/generate")
|
||||||
|
.method(Method::POST)
|
||||||
|
.header(http::header::CONTENT_TYPE, "application/json")
|
||||||
|
.body(axum::body::Body::from(request_body.to_string()))
|
||||||
|
.unwrap(),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY);
|
||||||
|
|
||||||
|
let body = response.into_body().collect().await.unwrap().to_bytes();
|
||||||
|
let utf8_body = std::str::from_utf8(&body[..]).unwrap();
|
||||||
|
|
||||||
|
let expected_response_body = json!({
|
||||||
|
"error":"Input validation error: `inputs` cannot be empty",
|
||||||
|
"error_type":"validation"
|
||||||
|
});
|
||||||
|
assert_eq!(utf8_body, expected_response_body.to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user