Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-10 20:04:52 +00:00
Pass max_total_tokens info through warmup, so the Python server can get max_total_tokens as truncate + max_new_tokens during warmup
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in: parent 123749a3c9, commit ba22ef54d4
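The diff below threads max_total_tokens from the router into every warmup request, so the Python server can recover it as truncate + max_new_tokens. A minimal sketch of that recovery, assuming a hypothetical Python-side warmup handler: the names warmup, batch, and requests are illustrative, and only truncate and stopping_parameters.max_new_tokens correspond to request fields shown in the diff.

# Hypothetical sketch, not the actual server API: recover max_total_tokens
# inside the Python warmup from the dummy requests built by the router.
def warmup(batch):
    max_total_tokens = 0
    for r in batch.requests:
        # The router sets max_new_tokens = max_total_tokens - truncate,
        # so truncate + max_new_tokens reconstructs max_total_tokens.
        max_total_tokens = max(
            max_total_tokens,
            r.truncate + r.stopping_parameters.max_new_tokens,
        )
    return max_total_tokens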
@@ -103,17 +103,19 @@ impl Client {
         &mut self,
         max_input_length: u32,
         max_prefill_tokens: u32,
+        max_total_tokens: u32,
     ) -> Result<Option<u32>> {
         let mut n_tokens = 0;
         let mut requests = Vec::new();
-
+        let mut truncate = 0;
         // Create requests
         while n_tokens < max_prefill_tokens {
+            truncate = min(max_input_length, max_prefill_tokens - n_tokens);
             requests.push(Request {
                 id: 0,
                 // We truncate the input on the server side to be sure that it has the correct size
                 inputs: "_test ".to_string().repeat(max_input_length as usize),
-                truncate: min(max_input_length, max_prefill_tokens - n_tokens),
+                truncate: truncate,
                 // Set sampling parameters to also take these ops into account in the max memory
                 parameters: Some(NextTokenChooserParameters {
                     temperature: 0.9,
@@ -126,9 +128,9 @@ impl Client {
                     watermark: true,
                 }),
                 stopping_parameters: Some(StoppingCriteriaParameters {
-                    max_new_tokens: 2,
+                    max_new_tokens: max_total_tokens - truncate,
                     stop_sequences: vec![],
-                    ignore_eos_token: false,
+                    ignore_eos_token: true,
                 }),
                 prefill_logprobs: true,
                 top_n_tokens: 20,
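As a quick check of the arithmetic in the hunks above, with illustrative numbers only (the loop increment is not shown in the hunk and is assumed here): every dummy request ends up with truncate + max_new_tokens equal to max_total_tokens, which is the relation the Python side depends on.

# Illustrative numbers only; mirrors the truncate / max_new_tokens
# computation from the Rust warmup loop above.
max_input_length = 1024
max_prefill_tokens = 4096
max_total_tokens = 2048

n_tokens = 0
while n_tokens < max_prefill_tokens:
    truncate = min(max_input_length, max_prefill_tokens - n_tokens)
    max_new_tokens = max_total_tokens - truncate
    assert truncate + max_new_tokens == max_total_tokens
    n_tokens += max_input_length  # increment assumed, not shown in the hunk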
@@ -95,11 +95,12 @@ impl ShardedClient {
         &mut self,
         max_input_length: u32,
         max_prefill_tokens: u32,
+        max_total_tokens: u32,
     ) -> Result<Option<u32>> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
-            .map(|client| Box::pin(client.warmup(max_input_length, max_prefill_tokens)))
+            .map(|client| Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens)))
             .collect();
         // Take the minimum value
         let results = join_all(futures)
@@ -212,7 +212,7 @@ fn main() -> Result<(), RouterError> {
     // Warmup model
     tracing::info!("Warming up model");
     let max_supported_batch_total_tokens = match sharded_client
-        .warmup(max_input_length as u32, max_batch_prefill_tokens)
+        .warmup(max_input_length as u32, max_batch_prefill_tokens, max_total_tokens as u32)
         .await
         .map_err(RouterError::Warmup)?
     {