mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 11:24:53 +00:00
Tmp.
This commit is contained in:
parent
45344244cf
commit
1c97d7b0c0
25
Cargo.lock
generated
25
Cargo.lock
generated
@ -183,6 +183,21 @@ dependencies = [
|
||||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "axum-test-server"
|
||||
version = "2.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "57f8f8627d32fe7e2c36b33de0e87dcdee4d6ac8619b9b892e5cc299ea4eed52"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"axum",
|
||||
"hyper",
|
||||
"portpicker",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "axum-tracing-opentelemetry"
|
||||
version = "0.10.0"
|
||||
@ -1736,6 +1751,15 @@ version = "0.3.19"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "26f6a7b87c2e435a3241addceeeff740ff8b7e76b74c13bf9acb17fa454ea00b"
|
||||
|
||||
[[package]]
|
||||
name = "portpicker"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "be97d76faf1bfab666e1375477b23fde79eccf0276e9b63b92a39d676a889ba9"
|
||||
dependencies = [
|
||||
"rand",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ppv-lite86"
|
||||
version = "0.2.17"
|
||||
@ -2407,6 +2431,7 @@ version = "0.6.0"
|
||||
dependencies = [
|
||||
"async-stream",
|
||||
"axum",
|
||||
"axum-test-server",
|
||||
"axum-tracing-opentelemetry",
|
||||
"clap",
|
||||
"flume",
|
||||
|
@ -42,3 +42,6 @@ utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] }
|
||||
|
||||
[build-dependencies]
|
||||
vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] }
|
||||
|
||||
[dev-dependencies]
|
||||
axum-test-server = "2.0.0"
|
||||
|
@ -276,3 +276,19 @@ pub(crate) struct ErrorResponse {
|
||||
pub error: String,
|
||||
pub error_type: String,
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use std::io::Write;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tokenizers::Tokenizer;

    /// Returns the GPT-2 tokenizer shared by the crate's tests.
    ///
    /// The tokenizer definition is cached as `tokenizer.json` in the current
    /// working directory. When that file is missing it is downloaded from the
    /// Hugging Face Hub first.
    ///
    /// # Panics
    ///
    /// Panics if the download fails (including a non-success HTTP status), if
    /// the file cannot be written, or if the cached file cannot be parsed.
    pub(crate) async fn get_tokenizer() -> Tokenizer {
        const TOKENIZER_PATH: &str = "tokenizer.json";
        // Distinguishes concurrent downloads started from the same test process.
        static DOWNLOADS: AtomicUsize = AtomicUsize::new(0);

        if !std::path::Path::new(TOKENIZER_PATH).exists() {
            let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json")
                .await
                .unwrap()
                // Without this, an HTTP error page would be cached as the
                // tokenizer and break every later run until deleted by hand.
                .error_for_status()
                .unwrap()
                .bytes()
                .await
                .unwrap();

            // Tests run in parallel: write to a private temporary file and
            // rename it into place so that no other test can ever observe a
            // partially written `tokenizer.json`.
            let partial_path = format!(
                "{}.{}.{}.tmp",
                TOKENIZER_PATH,
                std::process::id(),
                DOWNLOADS.fetch_add(1, Ordering::Relaxed)
            );
            let mut file = std::fs::File::create(&partial_path).unwrap();
            file.write_all(&content).unwrap();
            // Close the handle before renaming.
            drop(file);
            std::fs::rename(&partial_path, TOKENIZER_PATH).unwrap();
        }
        Tokenizer::from_file(TOKENIZER_PATH).unwrap()
    }
}
|
||||
|
||||
|
@ -741,3 +741,63 @@ impl From<InferError> for Event {
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::default_parameters;
    use crate::tests::get_tokenizer;
    use axum_test_server::TestServer;

    /// Serves the `/health` route from a test server and checks its reply.
    ///
    /// The test first checks that `best_of` without sampling is rejected by
    /// validation, then connects to a shard over the Unix socket
    /// `/tmp/text-generation-test` (the connection result is unwrapped, so a
    /// shard must be listening there) and finally queries `/health`.
    #[tokio::test]
    async fn test_health() {
        // Validation limits.
        let tokenizer = Some(get_tokenizer().await);
        let workers = 1;
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_input_length = 4;
        let max_total_tokens = 5;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequence,
            max_input_length,
            max_total_tokens,
        );

        // `best_of > 1` with `do_sample: false` must be refused.
        assert!(
            matches!(
                validation
                    .validate(GenerateRequest {
                        inputs: "Hello".to_string(),
                        parameters: GenerateParameters {
                            best_of: Some(2),
                            do_sample: false,
                            ..default_parameters()
                        },
                    })
                    .await,
                Err(ValidationError::BestOfSampling)
            ),
            "Unexpected not best of sampling"
        );

        // Inference settings.
        let client = ShardedClient::connect_uds("/tmp/text-generation-test".to_string())
            .await
            .unwrap();
        let waiting_served_ratio = 1.2;
        let max_batch_total_tokens = 100;
        let max_waiting_tokens = 10;
        let max_concurrent_requests = 10;
        let requires_padding = false;
        let infer = Infer::new(
            client,
            validation,
            waiting_served_ratio,
            max_batch_total_tokens,
            max_waiting_tokens,
            max_concurrent_requests,
            requires_padding,
        );

        let health_router = Router::new()
            .route("/health", get(health))
            .layer(Extension(infer));

        // Run the server on a random address.
        let server = TestServer::new(health_router.into_make_service());

        // Get the request.
        let response = server.get("/health").await;

        assert_eq!(response.contents, "pong!");
    }
}
|
||||
|
@ -382,6 +382,7 @@ pub enum ValidationError {
|
||||
#[cfg(test)]
|
||||
mod tests{
|
||||
use super::*;
|
||||
use crate::default_parameters;
|
||||
use std::io::Write;
|
||||
|
||||
#[tokio::test]
|
||||
@ -426,4 +427,73 @@ mod tests{
|
||||
_ => panic!("Unexpected not max new tokens")
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_validation_best_of_sampling(){
|
||||
let tokenizer = Some(get_tokenizer().await);
|
||||
let max_best_of = 2;
|
||||
let max_stop_sequence = 3;
|
||||
let max_input_length = 4;
|
||||
let max_total_tokens = 5;
|
||||
let workers = 1;
|
||||
let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens);
|
||||
match validation.validate(GenerateRequest{
|
||||
inputs: "Hello".to_string(),
|
||||
parameters: GenerateParameters{
|
||||
best_of: Some(2),
|
||||
do_sample: false,
|
||||
..default_parameters()
|
||||
}
|
||||
}).await{
|
||||
Err(ValidationError::BestOfSampling) => (),
|
||||
_ => panic!("Unexpected not best of sampling")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_validation_top_p(){
|
||||
let tokenizer = Some(get_tokenizer().await);
|
||||
let max_best_of = 2;
|
||||
let max_stop_sequence = 3;
|
||||
let max_input_length = 4;
|
||||
let max_total_tokens = 5;
|
||||
let workers = 1;
|
||||
let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens);
|
||||
match validation.validate(GenerateRequest{
|
||||
inputs: "Hello".to_string(),
|
||||
parameters: GenerateParameters{
|
||||
top_p: Some(1.0),
|
||||
..default_parameters()
|
||||
}
|
||||
}).await{
|
||||
Err(ValidationError::TopP) => (),
|
||||
_ => panic!("Unexpected top_p")
|
||||
}
|
||||
|
||||
match validation.validate(GenerateRequest{
|
||||
inputs: "Hello".to_string(),
|
||||
parameters: GenerateParameters{
|
||||
top_p: Some(0.99),
|
||||
max_new_tokens: 1,
|
||||
..default_parameters()
|
||||
}
|
||||
}).await{
|
||||
Ok(_) => (),
|
||||
_ => panic!("Unexpected top_p error")
|
||||
}
|
||||
|
||||
let valid_request = validation.validate(GenerateRequest{
|
||||
inputs: "Hello".to_string(),
|
||||
parameters: GenerateParameters{
|
||||
top_p: None,
|
||||
max_new_tokens: 1,
|
||||
..default_parameters()
|
||||
}
|
||||
}).await.unwrap();
|
||||
// top_p == 1.0 is invalid for users to ask for but it's the default resolved value.
|
||||
assert_eq!(valid_request.parameters.top_p, 1.0);
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user