mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
Update ever so slightly current queue tests.
This commit is contained in:
parent
1c97d7b0c0
commit
67356dc9a2
3
.github/workflows/tests.yaml
vendored
3
.github/workflows/tests.yaml
vendored
@ -67,6 +67,9 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
pip install pytest
|
pip install pytest
|
||||||
HF_HUB_ENABLE_HF_TRANSFER=1 pytest -sv server/tests
|
HF_HUB_ENABLE_HF_TRANSFER=1 pytest -sv server/tests
|
||||||
|
- name: Run Clippy
|
||||||
|
run: |
|
||||||
|
cargo clippy
|
||||||
- name: Run Rust tests
|
- name: Run Rust tests
|
||||||
run: |
|
run: |
|
||||||
cargo test
|
cargo test
|
||||||
|
25
Cargo.lock
generated
25
Cargo.lock
generated
@ -183,21 +183,6 @@ dependencies = [
|
|||||||
"tower-service",
|
"tower-service",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "axum-test-server"
|
|
||||||
version = "2.0.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "57f8f8627d32fe7e2c36b33de0e87dcdee4d6ac8619b9b892e5cc299ea4eed52"
|
|
||||||
dependencies = [
|
|
||||||
"anyhow",
|
|
||||||
"axum",
|
|
||||||
"hyper",
|
|
||||||
"portpicker",
|
|
||||||
"serde",
|
|
||||||
"serde_json",
|
|
||||||
"tokio",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "axum-tracing-opentelemetry"
|
name = "axum-tracing-opentelemetry"
|
||||||
version = "0.10.0"
|
version = "0.10.0"
|
||||||
@ -1751,15 +1736,6 @@ version = "0.3.19"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "26f6a7b87c2e435a3241addceeeff740ff8b7e76b74c13bf9acb17fa454ea00b"
|
checksum = "26f6a7b87c2e435a3241addceeeff740ff8b7e76b74c13bf9acb17fa454ea00b"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "portpicker"
|
|
||||||
version = "0.1.1"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "be97d76faf1bfab666e1375477b23fde79eccf0276e9b63b92a39d676a889ba9"
|
|
||||||
dependencies = [
|
|
||||||
"rand",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ppv-lite86"
|
name = "ppv-lite86"
|
||||||
version = "0.2.17"
|
version = "0.2.17"
|
||||||
@ -2431,7 +2407,6 @@ version = "0.6.0"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"async-stream",
|
"async-stream",
|
||||||
"axum",
|
"axum",
|
||||||
"axum-test-server",
|
|
||||||
"axum-tracing-opentelemetry",
|
"axum-tracing-opentelemetry",
|
||||||
"clap",
|
"clap",
|
||||||
"flume",
|
"flume",
|
||||||
|
@ -42,6 +42,3 @@ utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] }
|
|||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] }
|
vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] }
|
||||||
|
|
||||||
[dev-dependencies]
|
|
||||||
axum-test-server = "2.0.0"
|
|
||||||
|
@ -141,6 +141,9 @@ impl State {
|
|||||||
|
|
||||||
// Get the next batch
|
// Get the next batch
|
||||||
fn next_batch(&mut self, min_size: Option<usize>, token_budget: u32) -> Option<NextBatch> {
|
fn next_batch(&mut self, min_size: Option<usize>, token_budget: u32) -> Option<NextBatch> {
|
||||||
|
|
||||||
|
println!("Next batch {min_size:?} {token_budget:?}");
|
||||||
|
println!("{:?}",self.entries);
|
||||||
if self.entries.is_empty() {
|
if self.entries.is_empty() {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
@ -430,7 +433,17 @@ mod tests {
|
|||||||
let (entry3, _guard3) = default_entry();
|
let (entry3, _guard3) = default_entry();
|
||||||
queue.append(entry3);
|
queue.append(entry3);
|
||||||
|
|
||||||
|
// Not enough requests pending
|
||||||
assert!(queue.next_batch(Some(2), 2).await.is_none());
|
assert!(queue.next_batch(Some(2), 2).await.is_none());
|
||||||
|
// Not enough token budget
|
||||||
|
assert!(queue.next_batch(Some(1), 0).await.is_none());
|
||||||
|
// Ok
|
||||||
|
let (entries2, batch2, _) = queue.next_batch(Some(1), 2).await.unwrap();
|
||||||
|
assert_eq!(entries2.len(), 1);
|
||||||
|
assert!(entries2.contains_key(&2));
|
||||||
|
assert!(entries2.get(&2).unwrap().batch_time.is_some());
|
||||||
|
assert_eq!(batch2.id, 1);
|
||||||
|
assert_eq!(batch2.size, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
|
@ -742,62 +742,3 @@ impl From<InferError> for Event {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests{
|
|
||||||
use super::*;
|
|
||||||
use crate::tests::get_tokenizer;
|
|
||||||
use axum_test_server::TestServer;
|
|
||||||
use crate::default_parameters;
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_health(){
|
|
||||||
let tokenizer = Some(get_tokenizer().await);
|
|
||||||
let max_best_of = 2;
|
|
||||||
let max_stop_sequence = 3;
|
|
||||||
let max_input_length = 4;
|
|
||||||
let max_total_tokens = 5;
|
|
||||||
let workers = 1;
|
|
||||||
let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens);
|
|
||||||
match validation.validate(GenerateRequest{
|
|
||||||
inputs: "Hello".to_string(),
|
|
||||||
parameters: GenerateParameters{
|
|
||||||
best_of: Some(2),
|
|
||||||
do_sample: false,
|
|
||||||
..default_parameters()
|
|
||||||
}
|
|
||||||
}).await{
|
|
||||||
Err(ValidationError::BestOfSampling) => (),
|
|
||||||
_ => panic!("Unexpected not best of sampling")
|
|
||||||
}
|
|
||||||
|
|
||||||
let client = ShardedClient::connect_uds("/tmp/text-generation-test".to_string()).await.unwrap();
|
|
||||||
let waiting_served_ratio = 1.2;
|
|
||||||
let max_batch_total_tokens = 100;
|
|
||||||
let max_waiting_tokens = 10;
|
|
||||||
let max_concurrent_requests = 10;
|
|
||||||
let requires_padding = false;
|
|
||||||
let infer = Infer::new(
|
|
||||||
client,
|
|
||||||
validation,
|
|
||||||
waiting_served_ratio,
|
|
||||||
max_batch_total_tokens,
|
|
||||||
max_waiting_tokens,
|
|
||||||
max_concurrent_requests,
|
|
||||||
requires_padding,
|
|
||||||
);
|
|
||||||
let app = Router::new()
|
|
||||||
.route("/health", get(health))
|
|
||||||
.layer(Extension(infer))
|
|
||||||
.into_make_service();
|
|
||||||
|
|
||||||
// Run the server on a random address.
|
|
||||||
let server = TestServer::new(app);
|
|
||||||
|
|
||||||
// Get the request.
|
|
||||||
let response = server
|
|
||||||
.get("/health")
|
|
||||||
.await;
|
|
||||||
|
|
||||||
assert_eq!(response.contents, "pong!");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -383,7 +383,7 @@ pub enum ValidationError {
|
|||||||
mod tests{
|
mod tests{
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::default_parameters;
|
use crate::default_parameters;
|
||||||
use std::io::Write;
|
use crate::tests::get_tokenizer;
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_validation_max_new_tokens(){
|
async fn test_validation_max_new_tokens(){
|
||||||
@ -402,15 +402,6 @@ mod tests{
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_tokenizer() -> Tokenizer{
|
|
||||||
if !std::path::Path::new("tokenizer.json").exists(){
|
|
||||||
let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json").await.unwrap().bytes().await.unwrap();
|
|
||||||
let mut file = std::fs::File::create("tokenizer.json").unwrap();
|
|
||||||
file.write_all(&content).unwrap();
|
|
||||||
}
|
|
||||||
Tokenizer::from_file("tokenizer.json").unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_validation_input_length(){
|
async fn test_validation_input_length(){
|
||||||
let tokenizer = Some(get_tokenizer().await);
|
let tokenizer = Some(get_tokenizer().await);
|
||||||
|
Loading…
Reference in New Issue
Block a user