diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md
index 712b4fc4..226d0e96 100644
--- a/docs/source/basic_tutorials/launcher.md
+++ b/docs/source/basic_tutorials/launcher.md
@@ -354,6 +354,14 @@ Options:
 
       [env: NGROK_EDGE=]
 
+```
+## BATCH_DIMENSION
+```shell
+      --batch-dimension
+          Specific flag for hardware targets that do not support unpadded inference For those we do not send the tokenizer to the router so that all the scheduling assumes those pad tokens exist (and potentially even more)
+
+          [env: BATCH_DIMENSION=]
+
 ```
 ## TOKENIZER_CONFIG_PATH
 ```shell
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index e6799cb3..15ca00f1 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -1040,7 +1040,7 @@ fn spawn_webserver(
         args.model_id,
     ];
 
-    if args.batch_dimension{
+    if args.batch_dimension {
         router_args.push("--batch-dimension".to_string());
     }
 
diff --git a/router/src/server.rs b/router/src/server.rs
index 42ddca68..6bd7ec0b 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -834,7 +834,7 @@ pub async fn run(
         max_top_n_tokens,
         max_input_length,
         max_total_tokens,
-        batch_dimension
+        batch_dimension,
     );
     let generation_health = Arc::new(AtomicBool::new(false));
     let health_ext = Health::new(client.clone(), generation_health.clone());
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 542ed656..166f0b88 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -69,7 +69,7 @@ impl Validation {
             max_top_n_tokens,
             max_input_length,
             max_total_tokens,
-            batch_dimension
+            batch_dimension,
         }
     }
 
@@ -107,7 +107,7 @@ impl Validation {
     ) -> Result<(String, usize, u32), ValidationError> {
         // If we have a fast tokenizer
         if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
-            if self.batch_dimension{
+            if self.batch_dimension {
                 let input_length = encoding.len();
 
                 // Get total tokens
@@ -135,7 +135,6 @@ impl Validation {
                 ));
             }
 
-            // metrics::histogram!("tgi_request_input_length", input_length as f64);
             return Ok((inputs, input_length, max_new_tokens));
         }
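
For context, here is a minimal, self-contained sketch of the padded-length validation that the `batch_dimension` branch above implies: when the flag is set, the tokenized length is used as-is (pad tokens included) and bounded against the configured limits. The struct, method, and error names below are simplified stand-ins for illustration only, not the actual types in `router/src/validation.rs`.

```rust
// Illustrative sketch of the padded-length check suggested by the
// `if self.batch_dimension { ... }` hunk; names are hypothetical.
#[derive(Debug)]
enum ValidationError {
    InputLength(usize, usize), // (got, max)
    TotalTokens(usize, usize), // (got, max)
}

struct Validation {
    max_input_length: usize,
    max_total_tokens: usize,
    batch_dimension: bool,
}

impl Validation {
    /// With `batch_dimension` set, the encoding length (padding included)
    /// is taken directly as the input length and checked against the limits.
    fn check_padded(
        &self,
        encoding_len: usize,
        max_new_tokens: u32,
    ) -> Result<usize, ValidationError> {
        assert!(self.batch_dimension, "this sketch only covers the padded path");

        if encoding_len > self.max_input_length {
            return Err(ValidationError::InputLength(encoding_len, self.max_input_length));
        }

        let total_tokens = encoding_len + max_new_tokens as usize;
        if total_tokens > self.max_total_tokens {
            return Err(ValidationError::TotalTokens(total_tokens, self.max_total_tokens));
        }

        Ok(encoding_len)
    }
}

fn main() {
    let validation = Validation {
        max_input_length: 1024,
        max_total_tokens: 2048,
        batch_dimension: true,
    };
    // 900 input tokens + 200 new tokens fits both limits.
    println!("{:?}", validation.check_padded(900, 200));
    // 900 + 1200 exceeds max_total_tokens and is rejected.
    println!("{:?}", validation.check_padded(900, 1200));
}
```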