diff --git a/router/client/src/client.rs b/router/client/src/client.rs
index 1823fca3..a42e23cb 100644
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -106,6 +106,7 @@ impl Client {
         max_input_length: u32,
         max_prefill_tokens: u32,
         max_total_tokens: u32,
+        max_batch_total_tokens: Option<u32>,
     ) -> Result<Option<u32>> {
         let warmup_enabled: bool = env::var("WARMUP_ENABLED").ok().map_or(true, |value| value.to_lowercase() == "true");
         if !warmup_enabled {
@@ -123,7 +124,12 @@ impl Client {
 
         // get all possible sequence lengths for prefill
         let seq_bucket_size: u32 = read_env_var("PAD_SEQUENCE_TO_MULTIPLE_OF", 128);
-        let seq_lengths: Vec<u32> = (seq_bucket_size..max_input_length+1).step_by(seq_bucket_size as usize).collect();
+        let mut seq_lengths: Vec<u32> = (seq_bucket_size..max_input_length+1).step_by(seq_bucket_size as usize).collect();
+        if let Some(&last) = seq_lengths.last() {
+            if last < max_input_length {
+                seq_lengths.push(max_input_length);
+            }
+        }
 
         // execute batch for each combination of batch size and sequence length
         let mut shapes: Vec<(u32, u32)> = Vec::with_capacity(batch_sizes.len() * seq_lengths.len());
@@ -134,11 +140,29 @@ impl Client {
         }
 
         let mut id_counter: u64 = 0;
+        let num_batches = match max_batch_total_tokens {
+            Some(val) => {
+                if val == max_total_tokens {
+                    1
+                } else {
+                    2
+                }
+            }
+            None => 2, // If max_batch_total_tokens is None, create two batches
+        };
         for shape in shapes.iter() {
             // create two batches in order to trigger concatenate operation
+            // in case decode bs=1 create one batch
             let batches: Vec<Batch> = vec![
-                self.create_warmup_batch(*shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size, false),
-                self.create_warmup_batch(*shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size, false)
+                self.create_warmup_batch(
+                    *shape,
+                    &mut id_counter,
+                    max_input_length,
+                    max_total_tokens,
+                    seq_bucket_size,
+                    false,
+                );
+                num_batches
             ];
             let request = tonic::Request::new(WarmupRequest { batches }).inject_context();
             let _response = self.stub.warmup(request).await?.into_inner();
@@ -151,8 +175,15 @@ impl Client {
         }
         for greedy_shape in greedy_shapes.iter() {
             let batches: Vec<Batch> = vec![
-                self.create_warmup_batch(*greedy_shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size, true),
-                self.create_warmup_batch(*greedy_shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size, true),
+                self.create_warmup_batch(
+                    *greedy_shape,
+                    &mut id_counter,
+                    max_input_length,
+                    max_total_tokens,
+                    seq_bucket_size,
+                    true,
+                );
+                num_batches
             ];
             let request = tonic::Request::new(WarmupRequest { batches }).inject_context();
             let _response = self.stub.warmup(request).await?.into_inner();
@@ -233,14 +264,21 @@ impl Client {
             // generate random tokens
             let mut rng = rand::thread_rng();
             let range = Uniform::new(2, 8192);
-            let tokens = input_length - seq_bucket_size / 2;
+            let tokens = if input_length % seq_bucket_size == 0 {
+                input_length - seq_bucket_size / 2
+            } else {
+                input_length - (input_length % seq_bucket_size) / 2
+            };
             (0..tokens)
                 .map(|_| rng.sample(&range).to_string())
                 .collect::<Vec<String>>()
                 .join(", ")
         } else {
             // repeat test string to get expected input shape
-            let bucket_id = input_length / seq_bucket_size;
+            let mut bucket_id = input_length / seq_bucket_size;
+            if input_length % seq_bucket_size != 0 {
+                bucket_id += 1
+            }
             let repeats = cmp::max(1, (bucket_id - 1) * seq_bucket_size / 2);
             "_test ".to_string().repeat(repeats as usize)
         }
diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs
index b4bdcd42..27462bce 100644
--- a/router/client/src/sharded_client.rs
+++ b/router/client/src/sharded_client.rs
@@ -96,12 +96,13 @@ impl ShardedClient {
         max_input_length: u32,
         max_prefill_tokens: u32,
         max_total_tokens: u32,
+        max_batch_total_tokens: Option<u32>,
     ) -> Result<Option<u32>> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
             .map(|client| {
-                Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens))
+                Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens, max_batch_total_tokens))
             })
             .collect();
         // Take the minimum value
diff --git a/router/src/main.rs b/router/src/main.rs
index fbe0bca5..b6937a51 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -220,7 +220,7 @@ fn main() -> Result<(), RouterError> {
     // Warmup model
     tracing::info!("Warming up model");
     let max_supported_batch_total_tokens = match sharded_client
-        .warmup(max_input_length as u32, max_batch_prefill_tokens, max_total_tokens as u32)
+        .warmup(max_input_length as u32, max_batch_prefill_tokens, max_total_tokens as u32, max_batch_total_tokens)
        .await
        .map_err(RouterError::Warmup)?
    {
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 0c5a7288..f8ef8050 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -461,7 +461,11 @@ class CausalLMBatch(Batch):
         left_padding = max_input_length - input_len
         if input_len < max_input_length and PAD_SEQUENCE_TO_MULTIPLE_OF != 0:
             assert PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length, "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
-            bucket_size = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF) - 1
+            rounded_seq_len = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
+            if rounded_seq_len <= max_input_length:
+                bucket_size = rounded_seq_len - 1
+            else:
+                bucket_size = max_input_length - 1
             left_padding = bucket_size - input_len
 
         input_ids = tokenized_inputs["input_ids"]
@@ -1049,15 +1053,17 @@ class CausalLM(Model):
         return generations, batch if not stopped else None
 
     def warmup(self, batches: List[CausalLMBatch]) -> None:
-        if len(batches) < 2:
-            return
-
         # prefill
         _, prefill_batch = self.generate_token([batches.pop(0)])
         # decode
         _, decode_batch = self.generate_token([prefill_batch])
         # shifts
         self.shifting_warmup(decode_batch)
+
+        # if decode bs is 1 warmup ends here
+        if len(batches) == 0:
+            return
+
         # prefill
         _, prefill_batch = self.generate_token([batches.pop(0)])
         # concatenate and decode
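
Note (not part of the patch): a minimal standalone Rust sketch of the changed warmup sequence-length bucketing, using hypothetical values for PAD_SEQUENCE_TO_MULTIPLE_OF and max_input_length. It only illustrates the new "append max_input_length when it is not on a bucket boundary" behavior shown above; the rest of the warmup logic is omitted.

// Sketch of the patched bucket generation; seq_bucket_size stands in for
// PAD_SEQUENCE_TO_MULTIPLE_OF and both argument values below are hypothetical.
fn warmup_seq_lengths(seq_bucket_size: u32, max_input_length: u32) -> Vec<u32> {
    // Old behavior: buckets at every multiple of seq_bucket_size up to max_input_length.
    let mut seq_lengths: Vec<u32> = (seq_bucket_size..max_input_length + 1)
        .step_by(seq_bucket_size as usize)
        .collect();
    // New behavior: also warm up max_input_length itself when it is not a
    // multiple of the bucket size, so the largest shape is never skipped.
    if let Some(&last) = seq_lengths.last() {
        if last < max_input_length {
            seq_lengths.push(max_input_length);
        }
    }
    seq_lengths
}

fn main() {
    // With bucket size 128 and max_input_length 1000, the old code stopped at 896;
    // the patched code also includes 1000.
    assert_eq!(
        warmup_seq_lengths(128, 1000),
        vec![128, 256, 384, 512, 640, 768, 896, 1000]
    );
    // When max_input_length falls exactly on a bucket boundary, nothing changes.
    assert_eq!(warmup_seq_lengths(128, 512), vec![128, 256, 384, 512]);
    println!("bucket examples hold");
}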