From 9d613d0f9ba6c543862ec9be1595edb03f794584 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 26 Apr 2023 14:25:55 +0200
Subject: [PATCH] Updating code.

---
 .github/workflows/tests.yaml            |  6 +++
 router/src/lib.rs                       |  9 -----
 router/src/server.rs                    | 11 ++----
 router/src/validation.rs                | 51 ++++++++++++++++++-------
 server/text_generation_server/server.py |  2 +-
 5 files changed, 49 insertions(+), 30 deletions(-)

diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml
index e82e8b20..26cbf3b2 100644
--- a/.github/workflows/tests.yaml
+++ b/.github/workflows/tests.yaml
@@ -67,6 +67,12 @@ jobs:
         run: |
           pip install pytest
           HF_HUB_ENABLE_HF_TRANSFER=1 pytest -sv server/tests
+      - name: Run Rust fmt
+        run: |
+          cargo fmt --check
+      - name: Run Rust clippy
+        run: |
+          cargo clippy
       - name: Run Rust tests
         run: |
           cargo test
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 4f73fa16..bf2112a9 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -20,20 +20,11 @@ pub struct HubModelInfo {
     pub pipeline_tag: Option<String>,
 }

-#[derive(Clone, Debug, Serialize, ToSchema)]
-pub struct HealthResponse {}
-
 #[derive(Clone, Debug)]
 pub struct Health {
     pub client: ShardedClient,
 }

-impl Health {
-    pub fn new(client: ShardedClient) -> Self {
-        Self { client }
-    }
-}
-
 #[derive(Clone, Debug, Serialize, ToSchema)]
 pub struct Info {
     /// Model info
diff --git a/router/src/server.rs b/router/src/server.rs
index ce8c59dc..97fd43d0 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -90,7 +90,7 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
     responses(
         (status = 200, description = "Everything is working fine"),
         (status = 500, description = "Text generation inference is down", body = ErrorResponse,
-            example = json ! ({"error": "unhealthy"})),
+            example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})),
     )
 )]
 #[instrument]
@@ -98,11 +98,6 @@
 async fn health(
     mut health: Extension<Health>,
 ) -> Result<Json<HealthResponse>, (StatusCode, Json<ErrorResponse>)> {
-    // TODO: while this is the best health check we can do, it is a bit on the heavy side and might
-    // be a bit too slow for a health check.
-    // What we should do instead is check if the gRPC channels are still healthy.
-
-    // Send a small inference request
     health.client.health().await.map_err(|_| {
         (
             StatusCode::INTERNAL_SERVER_ERROR,
@@ -589,7 +584,9 @@ pub async fn run(
         max_input_length,
         max_total_tokens,
     );
-    let health_ext = Health::new(client.clone());
+    let health_ext = Health {
+        client: client.clone(),
+    };
     let infer = Infer::new(
         client,
         validation,
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 983c2612..37147a03 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -380,50 +380,75 @@ pub enum ValidationError {
 }

 #[cfg(test)]
-mod tests{
+mod tests {
     use super::*;
     use std::io::Write;

     #[tokio::test]
-    async fn test_validation_max_new_tokens(){
+    async fn test_validation_max_new_tokens() {
         let tokenizer = None;
         let max_best_of = 2;
         let max_stop_sequence = 3;
         let max_input_length = 4;
         let max_total_tokens = 5;
         let workers = 1;

-        let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens);
+        let validation = Validation::new(
+            workers,
+            tokenizer,
+            max_best_of,
+            max_stop_sequence,
+            max_input_length,
+            max_total_tokens,
+        );
         let max_new_tokens = 10;
-        match validation.validate_input("Hello".to_string(), None, max_new_tokens).await{
+        match validation
+            .validate_input("Hello".to_string(), None, max_new_tokens)
+            .await
+        {
             Err(ValidationError::MaxNewTokens(1, 10)) => (),
-            _ => panic!("Unexpected not max new tokens")
+            _ => panic!("Unexpected not max new tokens"),
         }
     }

-    async fn get_tokenizer() -> Tokenizer{
-        if !std::path::Path::new("tokenizer.json").exists(){
-            let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json").await.unwrap().bytes().await.unwrap();
-            let mut file = std::fs::File::create("tokenizer.json").unwrap();
+    async fn get_tokenizer() -> Tokenizer {
+        if !std::path::Path::new("tokenizer.json").exists() {
+            let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json")
+                .await
+                .unwrap()
+                .bytes()
+                .await
+                .unwrap();
+            let mut file = std::fs::File::create("tokenizer.json").unwrap();
             file.write_all(&content).unwrap();
         }
         Tokenizer::from_file("tokenizer.json").unwrap()
     }

     #[tokio::test]
-    async fn test_validation_input_length(){
+    async fn test_validation_input_length() {
         let tokenizer = Some(get_tokenizer().await);
         let max_best_of = 2;
         let max_stop_sequence = 3;
         let max_input_length = 4;
         let max_total_tokens = 5;
         let workers = 1;

-        let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens);
+        let validation = Validation::new(
+            workers,
+            tokenizer,
+            max_best_of,
+            max_stop_sequence,
+            max_input_length,
+            max_total_tokens,
+        );
         let max_new_tokens = 10;
-        match validation.validate_input("Hello".to_string(), None, max_new_tokens).await{
+        match validation
+            .validate_input("Hello".to_string(), None, max_new_tokens)
+            .await
+        {
             Err(ValidationError::MaxTotalTokens(5, 1, 10)) => (),
-            _ => panic!("Unexpected not max new tokens")
+            _ => panic!("Unexpected not max new tokens"),
         }
     }
 }
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 605e1320..70f08ed7 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -31,7 +31,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):

     async def Health(self, request, context):
         if self.model.device.type == "cuda":
-            torch.zeros((2, 2)).to(device=f"cuda:{os.environ['RANK']}")
+            torch.zeros((2, 2)).cuda()
         return generate_pb2.HealthResponse()

     async def ServiceDiscovery(self, request, context):
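
A brief illustration of the pattern the patch settles on may help when reading the hunks above: the `Health` extension is now built with a plain struct literal (the `Health::new` constructor is removed), and the `/health` handler maps a failed client call to the `"unhealthy"` / `"healthcheck"` error body shown in the OpenAPI example. The sketch below is a minimal, self-contained stand-in, not the project's actual code: `MockShardedClient`, the simplified `ErrorResponse`, and the plain `health` function are illustrative substitutes for the crate's `ShardedClient` and axum handler types, and the snippet only assumes a `tokio` runtime.

```rust
// Minimal sketch of the health-check flow after this patch
// (illustrative stand-ins, not the project's real types).

#[derive(Clone, Debug)]
struct MockShardedClient;

impl MockShardedClient {
    // Stand-in for the ShardedClient::health() call awaited by the patched handler.
    async fn health(&self) -> Result<(), &'static str> {
        Ok(())
    }
}

#[derive(Clone, Debug)]
struct Health {
    client: MockShardedClient,
}

#[derive(Debug)]
struct ErrorResponse {
    error: String,
    error_type: String,
}

// Simplified handler: a failed client call becomes the same error body the
// patch documents in the OpenAPI example ("unhealthy" / "healthcheck").
async fn health(health: Health) -> Result<(), ErrorResponse> {
    health.client.health().await.map_err(|_| ErrorResponse {
        error: "unhealthy".to_string(),
        error_type: "healthcheck".to_string(),
    })
}

#[tokio::main]
async fn main() {
    let client = MockShardedClient;
    // Struct literal construction, as in the `run` hunk above.
    let health_ext = Health {
        client: client.clone(),
    };
    match health(health_ext).await {
        Ok(()) => println!("healthy"),
        Err(e) => println!("{} ({})", e.error, e.error_type),
    }
}
```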