From 40f693b6b93aaf826831b3a57a674aded61c2573 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 8 Feb 2024 15:04:27 +0000 Subject: [PATCH] Fix PR. --- docs/source/basic_tutorials/launcher.md | 8 ++++++++ launcher/src/main.rs | 2 +- router/src/server.rs | 2 +- router/src/validation.rs | 5 ++--- 4 files changed, 12 insertions(+), 5 deletions(-) diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md index 712b4fc4..226d0e96 100644 --- a/docs/source/basic_tutorials/launcher.md +++ b/docs/source/basic_tutorials/launcher.md @@ -354,6 +354,14 @@ Options: [env: NGROK_EDGE=] +``` +## BATCH_DIMENSION +```shell + --batch-dimension + Specific flag for hardware targets that do not support unpadded inference For those we do not send the tokenizer to the router so that all the scheduling assumes those pad tokens exist (and potentially even more) + + [env: BATCH_DIMENSION=] + ``` ## TOKENIZER_CONFIG_PATH ```shell diff --git a/launcher/src/main.rs b/launcher/src/main.rs index e6799cb3..15ca00f1 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -1040,7 +1040,7 @@ fn spawn_webserver( args.model_id, ]; - if args.batch_dimension{ + if args.batch_dimension { router_args.push("--batch-dimension".to_string()); } diff --git a/router/src/server.rs b/router/src/server.rs index 42ddca68..6bd7ec0b 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -834,7 +834,7 @@ pub async fn run( max_top_n_tokens, max_input_length, max_total_tokens, - batch_dimension + batch_dimension, ); let generation_health = Arc::new(AtomicBool::new(false)); let health_ext = Health::new(client.clone(), generation_health.clone()); diff --git a/router/src/validation.rs b/router/src/validation.rs index 542ed656..166f0b88 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -69,7 +69,7 @@ impl Validation { max_top_n_tokens, max_input_length, max_total_tokens, - batch_dimension + batch_dimension, } } @@ -107,7 +107,7 @@ impl Validation { ) -> Result<(String, usize, u32), ValidationError> { // If we have a fast tokenizer if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? { - if self.batch_dimension{ + if self.batch_dimension { let input_length = encoding.len(); // Get total tokens @@ -135,7 +135,6 @@ impl Validation { )); } - // metrics::histogram!("tgi_request_input_length", input_length as f64); return Ok((inputs, input_length, max_new_tokens)); }