From 5e28f44a834c20602d4cc18d28703e024d3bbbe0 Mon Sep 17 00:00:00 2001
From: OlivierDehaene
Date: Fri, 20 Oct 2023 10:28:45 +0200
Subject: [PATCH] #1049 CI (#1178)

See #1049

---------

Signed-off-by: Wang, Yi A
Co-authored-by: Wang, Yi
---
 router/client/src/client.rs         | 10 ++++++----
 router/client/src/sharded_client.rs |  5 ++++-
 router/src/main.rs                  |  2 +-
 router/src/validation.rs            |  2 +-
 4 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/router/client/src/client.rs b/router/client/src/client.rs
index d427d3a4d..f8f5df957 100644
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -103,17 +103,19 @@ impl Client {
         &mut self,
         max_input_length: u32,
         max_prefill_tokens: u32,
+        max_total_tokens: u32,
     ) -> Result<Option<u32>> {
         let mut n_tokens = 0;
         let mut requests = Vec::new();
-
+        let mut truncate = 0;
         // Create requests
         while n_tokens < max_prefill_tokens {
+            truncate = min(max_input_length, max_prefill_tokens - n_tokens);
             requests.push(Request {
                 id: 0,
                 // We truncate the input on the server side to be sure that it has the correct size
                 inputs: "_test ".to_string().repeat(max_input_length as usize),
-                truncate: min(max_input_length, max_prefill_tokens - n_tokens),
+                truncate: truncate,
                 // Set sampling parameters to also take these ops into account in the max memory
                 parameters: Some(NextTokenChooserParameters {
                     temperature: 0.9,
@@ -126,9 +128,9 @@ impl Client {
                     watermark: true,
                 }),
                 stopping_parameters: Some(StoppingCriteriaParameters {
-                    max_new_tokens: 2,
+                    max_new_tokens: max_total_tokens - truncate,
                     stop_sequences: vec![],
-                    ignore_eos_token: false,
+                    ignore_eos_token: true,
                 }),
                 prefill_logprobs: true,
                 top_n_tokens: 20,
diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs
index 112b0035a..b4bdcd424 100644
--- a/router/client/src/sharded_client.rs
+++ b/router/client/src/sharded_client.rs
@@ -95,11 +95,14 @@ impl ShardedClient {
         &mut self,
         max_input_length: u32,
         max_prefill_tokens: u32,
+        max_total_tokens: u32,
     ) -> Result<Option<u32>> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
-            .map(|client| Box::pin(client.warmup(max_input_length, max_prefill_tokens)))
+            .map(|client| {
+                Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens))
+            })
             .collect();
         // Take the minimum value
         let results = join_all(futures)
diff --git a/router/src/main.rs b/router/src/main.rs
index f30286749..d90632ef4 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -212,7 +212,7 @@ fn main() -> Result<(), RouterError> {
     // Warmup model
     tracing::info!("Warming up model");
     let max_supported_batch_total_tokens = match sharded_client
-        .warmup(max_input_length as u32, max_batch_prefill_tokens)
+        .warmup(max_input_length as u32, max_batch_prefill_tokens, max_total_tokens as u32)
         .await
         .map_err(RouterError::Warmup)?
     {
diff --git a/router/src/validation.rs b/router/src/validation.rs
index d0ea137d9..37465272a 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -122,7 +122,7 @@ impl Validation {
             if let Some(truncate) = truncate {
                 self.max_total_tokens.saturating_sub(truncate) as u32
             } else {
-                return Err(ValidationError::UnsetMaxNewTokens)
+                return Err(ValidationError::UnsetMaxNewTokens);
             }
         };
         let input_length = truncate.unwrap_or(self.max_input_length);
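
Reviewer note: the substance of this patch is that warmup now plans each
synthetic request to consume its full token budget (a prompt truncated to at
most max_input_length, then max_total_tokens - truncate decoded tokens with
EOS ignored) instead of stopping after 2 new tokens, so the reported
max_supported_batch_total_tokens reflects worst-case memory use. The sketch
below replays that budgeting in isolation. It is illustrative only:
WarmupRequest and plan_warmup are invented names, and the
n_tokens += max_input_length increment is assumed from the loop context the
hunk does not show; this is not the actual Client::warmup implementation.

    use std::cmp::min;

    // One synthetic warmup request: input tokens kept after truncation and
    // the number of new tokens it may decode before hitting the per-request
    // total-token ceiling.
    #[derive(Debug)]
    struct WarmupRequest {
        truncate: u32,
        max_new_tokens: u32,
    }

    // Plan the same schedule as the patched warmup loop: keep adding
    // requests until the prefill budget is spent; each request then decodes
    // up to max_total_tokens, with EOS ignored so the budget is really hit.
    fn plan_warmup(
        max_input_length: u32,
        max_prefill_tokens: u32,
        max_total_tokens: u32,
    ) -> Vec<WarmupRequest> {
        let mut n_tokens = 0;
        let mut requests = Vec::new();
        while n_tokens < max_prefill_tokens {
            // As in the patch: cap the prompt at the remaining prefill budget.
            let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
            requests.push(WarmupRequest {
                truncate,
                // Assumes max_total_tokens >= truncate, as the patch itself
                // does; otherwise this subtraction would underflow.
                max_new_tokens: max_total_tokens - truncate,
            });
            n_tokens += max_input_length; // assumed increment (not shown in the hunk)
        }
        requests
    }

    fn main() {
        // Example values only; the router receives these from launcher flags.
        for request in plan_warmup(1024, 4096, 2048) {
            println!("{request:?}");
        }
    }

With max_input_length = 1024, max_batch_prefill_tokens = 4096 and
max_total_tokens = 2048, this plans four requests of 1024 prompt tokens plus
1024 decoded tokens each, so the warmup batch genuinely reaches 4 * 2048
total tokens rather than 4 * (1024 + 2) as before the patch.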