From ba22ef54d41640958a16973888fa2f0875080832 Mon Sep 17 00:00:00 2001
From: "Wang, Yi A"
Date: Thu, 21 Sep 2023 20:50:50 -0700
Subject: [PATCH] pass max_total_tokens info through warmup so that python can
 get max_total_tokens as truncate + max_new_tokens in warmup

Signed-off-by: Wang, Yi A
---
 router/client/src/client.rs         | 10 ++++++----
 router/client/src/sharded_client.rs |  3 ++-
 router/src/main.rs                  |  2 +-
 3 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/router/client/src/client.rs b/router/client/src/client.rs
index d427d3a4..5c6ee38d 100644
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -103,17 +103,19 @@ impl Client {
         &mut self,
         max_input_length: u32,
         max_prefill_tokens: u32,
+        max_total_tokens: u32,
     ) -> Result<Option<u32>> {
         let mut n_tokens = 0;
         let mut requests = Vec::new();
-
+        let mut truncate = 0;
         // Create requests
         while n_tokens < max_prefill_tokens {
+            truncate = min(max_input_length, max_prefill_tokens - n_tokens);
             requests.push(Request {
                 id: 0,
                 // We truncate the input on the server side to be sure that it has the correct size
                 inputs: "_test ".to_string().repeat(max_input_length as usize),
-                truncate: min(max_input_length, max_prefill_tokens - n_tokens),
+                truncate: truncate,
                 // Set sampling parameters to also take these ops into account in the max memory
                 parameters: Some(NextTokenChooserParameters {
                     temperature: 0.9,
@@ -126,9 +128,9 @@ impl Client {
                     watermark: true,
                 }),
                 stopping_parameters: Some(StoppingCriteriaParameters {
-                    max_new_tokens: 2,
+                    max_new_tokens: max_total_tokens - truncate,
                     stop_sequences: vec![],
-                    ignore_eos_token: false,
+                    ignore_eos_token: true,
                 }),
                 prefill_logprobs: true,
                 top_n_tokens: 20,
diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs
index 112b0035..ef974d56 100644
--- a/router/client/src/sharded_client.rs
+++ b/router/client/src/sharded_client.rs
@@ -95,11 +95,12 @@ impl ShardedClient {
         &mut self,
         max_input_length: u32,
         max_prefill_tokens: u32,
+        max_total_tokens: u32,
     ) -> Result<Option<u32>> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
-            .map(|client| Box::pin(client.warmup(max_input_length, max_prefill_tokens)))
+            .map(|client| Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens)))
             .collect();
         // Take the minimum value
         let results = join_all(futures)
diff --git a/router/src/main.rs b/router/src/main.rs
index 4903c066..ea6047a4 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -212,7 +212,7 @@ fn main() -> Result<(), RouterError> {
     // Warmup model
     tracing::info!("Warming up model");
     let max_supported_batch_total_tokens = match sharded_client
-        .warmup(max_input_length as u32, max_batch_prefill_tokens)
+        .warmup(max_input_length as u32, max_batch_prefill_tokens, max_total_tokens as u32)
         .await
         .map_err(RouterError::Warmup)?
     {
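
Note on the server-side counterpart (not part of this patch): the diff above only touches the Rust router, and the subject line says the Python server can recover max_total_tokens during warmup. That works because every warmup request built by the router now satisfies truncate + max_new_tokens == max_total_tokens. Below is a minimal sketch of that derivation under stated assumptions; WarmupRequest and derive_max_total_tokens are hypothetical names, not the actual text-generation-server API.

# Sketch only: hypothetical types/names, not the real server classes.
from dataclasses import dataclass
from typing import Iterable


@dataclass
class WarmupRequest:
    # Mirrors the two fields the router sets on each warmup Request above.
    truncate: int
    max_new_tokens: int


def derive_max_total_tokens(requests: Iterable[WarmupRequest]) -> int:
    # With this patch, each warmup request is built so that
    # truncate + max_new_tokens == max_total_tokens, so the sum from any
    # request (here: the max over the batch) yields the router's setting.
    return max(r.truncate + r.max_new_tokens for r in requests)


# Example: router started with --max-input-length 1024 --max-total-tokens 2048
batch = [WarmupRequest(truncate=1024, max_new_tokens=1024)]
assert derive_max_total_tokens(batch) == 2048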