mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 23:12:07 +00:00
Adjust warmup to all possible bucket sizes and decode batch size = 1 (#113)
This commit is contained in:
parent 9796b0e10d
commit 56f00a552b
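Per the commit title, warmup now walks every prefill bucket up to max_input_length and sends only one decode batch when the decode batch size is 1. As a minimal, self-contained sketch of the bucket enumeration the `impl Client` hunks below add (the standalone function and `main` are illustrative only, not the router's API):

```rust
// Sketch only: enumerate every warmup sequence length in steps of the padding
// bucket, then append max_input_length itself when it is not an exact multiple.
fn warmup_seq_lengths(seq_bucket_size: u32, max_input_length: u32) -> Vec<u32> {
    let mut seq_lengths: Vec<u32> = (seq_bucket_size..max_input_length + 1)
        .step_by(seq_bucket_size as usize)
        .collect();
    if let Some(&last) = seq_lengths.last() {
        if last < max_input_length {
            seq_lengths.push(max_input_length);
        }
    }
    seq_lengths
}

fn main() {
    // bucket 128, max input 1000 -> [128, 256, ..., 896, 1000]
    println!("{:?}", warmup_seq_lengths(128, 1000));
}
```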
@@ -106,6 +106,7 @@ impl Client {
         max_input_length: u32,
         max_prefill_tokens: u32,
         max_total_tokens: u32,
+        max_batch_total_tokens: Option<u32>,
     ) -> Result<Option<u32>> {
         let warmup_enabled: bool = env::var("WARMUP_ENABLED").ok().map_or(true, |value| value.to_lowercase() == "true");
         if !warmup_enabled {
@@ -123,7 +124,12 @@ impl Client {
 
         // get all possible sequence lengths for prefill
         let seq_bucket_size: u32 = read_env_var("PAD_SEQUENCE_TO_MULTIPLE_OF", 128);
-        let seq_lengths: Vec<u32> = (seq_bucket_size..max_input_length+1).step_by(seq_bucket_size as usize).collect();
+        let mut seq_lengths: Vec<u32> = (seq_bucket_size..max_input_length+1).step_by(seq_bucket_size as usize).collect();
+        if let Some(&last) = seq_lengths.last() {
+            if last < max_input_length {
+                seq_lengths.push(max_input_length);
+            }
+        }
 
         // execute batch for each combination of batch size and sequence length
         let mut shapes: Vec<(u32, u32)> = Vec::with_capacity(batch_sizes.len() * seq_lengths.len());
@@ -134,11 +140,29 @@ impl Client {
         }
 
         let mut id_counter: u64 = 0;
+        let num_batches = match max_batch_total_tokens {
+            Some(val) => {
+                if val == max_total_tokens {
+                    1
+                } else {
+                    2
+                }
+            }
+            None => 2, // If max_batch_total_tokens is None, create two batches
+        };
         for shape in shapes.iter() {
             // create two batches in order to trigger concatenate operation
+            // in case decode bs=1 create one batch
             let batches: Vec<Batch> = vec![
-                self.create_warmup_batch(*shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size, false),
-                self.create_warmup_batch(*shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size, false)
+                self.create_warmup_batch(
+                    *shape,
+                    &mut id_counter,
+                    max_input_length,
+                    max_total_tokens,
+                    seq_bucket_size,
+                    false,
+                );
+                num_batches
             ];
             let request = tonic::Request::new(WarmupRequest { batches }).inject_context();
             let _response = self.stub.warmup(request).await?.into_inner();
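The batch list above is built with the `vec![value; count]` repeat form, so a `num_batches` of 1 (decode batch size 1, i.e. max_batch_total_tokens == max_total_tokens) skips the concatenate warmup while 2 still exercises it. A tiny illustrative snippet of that form, using a placeholder element rather than a real `Batch`:

```rust
fn main() {
    // Hypothetical stand-in for the warmup batch; vec![value; count] clones it count times.
    let num_batches = 1; // 1 when the decode batch size is 1, otherwise 2
    let batches = vec!["warmup-batch"; num_batches];
    assert_eq!(batches.len(), 1);
}
```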
@@ -151,8 +175,15 @@ impl Client {
         }
         for greedy_shape in greedy_shapes.iter() {
             let batches: Vec<Batch> = vec![
-                self.create_warmup_batch(*greedy_shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size, true),
-                self.create_warmup_batch(*greedy_shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size, true),
+                self.create_warmup_batch(
+                    *greedy_shape,
+                    &mut id_counter,
+                    max_input_length,
+                    max_total_tokens,
+                    seq_bucket_size,
+                    true,
+                );
+                num_batches
             ];
             let request = tonic::Request::new(WarmupRequest { batches }).inject_context();
             let _response = self.stub.warmup(request).await?.into_inner();
@@ -233,14 +264,21 @@ impl Client {
             // generate random tokens
             let mut rng = rand::thread_rng();
             let range = Uniform::new(2, 8192);
-            let tokens = input_length - seq_bucket_size / 2;
+            let tokens = if input_length % seq_bucket_size == 0 {
+                input_length - seq_bucket_size / 2
+            } else {
+                input_length - (input_length % seq_bucket_size) / 2
+            };
             (0..tokens)
                 .map(|_| rng.sample(&range).to_string())
                 .collect::<Vec<String>>()
                 .join(", ")
         } else {
             // repeat test string to get expected input shape
-            let bucket_id = input_length / seq_bucket_size;
+            let mut bucket_id = input_length / seq_bucket_size;
+            if input_length % seq_bucket_size != 0 {
+                bucket_id += 1
+            }
             let repeats = cmp::max(1, (bucket_id - 1) * seq_bucket_size / 2);
             "_test ".to_string().repeat(repeats as usize)
         }
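The new branch above matters for lengths that are not exact bucket multiples (notably the max_input_length entry pushed onto seq_lengths): instead of subtracting half a full bucket, it subtracts half of the remainder. A short, hypothetical sketch of that rule with worked numbers (the helper name is illustrative):

```rust
// Sketch of the rule above: generate fewer random words than the target length,
// presumably leaving headroom so tokenization stays inside the intended bucket.
fn warmup_token_count(input_length: u32, seq_bucket_size: u32) -> u32 {
    if input_length % seq_bucket_size == 0 {
        input_length - seq_bucket_size / 2
    } else {
        input_length - (input_length % seq_bucket_size) / 2
    }
}

fn main() {
    assert_eq!(warmup_token_count(256, 128), 192); // exact multiple: 256 - 128/2
    assert_eq!(warmup_token_count(300, 128), 278); // remainder 44: 300 - 44/2
}
```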
@@ -96,12 +96,13 @@ impl ShardedClient {
         max_input_length: u32,
         max_prefill_tokens: u32,
         max_total_tokens: u32,
+        max_batch_total_tokens: Option<u32>,
     ) -> Result<Option<u32>> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
             .map(|client| {
-                Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens))
+                Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens, max_batch_total_tokens))
             })
             .collect();
         // Take the minimum value
@@ -220,7 +220,7 @@ fn main() -> Result<(), RouterError> {
     // Warmup model
     tracing::info!("Warming up model");
     let max_supported_batch_total_tokens = match sharded_client
-        .warmup(max_input_length as u32, max_batch_prefill_tokens, max_total_tokens as u32)
+        .warmup(max_input_length as u32, max_batch_prefill_tokens, max_total_tokens as u32, max_batch_total_tokens)
         .await
         .map_err(RouterError::Warmup)?
     {
@@ -461,7 +461,11 @@ class CausalLMBatch(Batch):
         left_padding = max_input_length - input_len
         if input_len < max_input_length and PAD_SEQUENCE_TO_MULTIPLE_OF != 0:
             assert PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length, "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
-            bucket_size = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF) - 1
+            rounded_seq_len = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
+            if rounded_seq_len <= max_input_length:
+                bucket_size = rounded_seq_len - 1
+            else:
+                bucket_size = max_input_length - 1
             left_padding = bucket_size - input_len
 
         input_ids = tokenized_inputs["input_ids"]
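For the padding change above, a worked example of the arithmetic, rendered in Rust purely for illustration and assuming `round_up` rounds to the next multiple, as its use here suggests:

```rust
// Illustrative rendering of the Python bucket computation above.
fn round_up(value: u32, multiple: u32) -> u32 {
    (value + multiple - 1) / multiple * multiple
}

fn main() {
    let (pad_multiple, max_input_length) = (128, 1000);

    // Usual case: input_len 200 rounds up to the 256 bucket, giving 255 plus padding.
    let input_len = 200;
    let rounded = round_up(input_len + 1, pad_multiple); // 256
    let bucket_size = if rounded <= max_input_length { rounded - 1 } else { max_input_length - 1 };
    assert_eq!((bucket_size, bucket_size - input_len), (255, 55));

    // New case: rounding past max_input_length now clamps to max_input_length - 1.
    let input_len = 960;
    let rounded = round_up(input_len + 1, pad_multiple); // 1024 > 1000
    let bucket_size = if rounded <= max_input_length { rounded - 1 } else { max_input_length - 1 };
    assert_eq!((bucket_size, bucket_size - input_len), (999, 39));
}
```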
@@ -1049,15 +1053,17 @@ class CausalLM(Model):
         return generations, batch if not stopped else None
 
     def warmup(self, batches: List[CausalLMBatch]) -> None:
-        if len(batches) < 2:
-            return
-
         # prefill
         _, prefill_batch = self.generate_token([batches.pop(0)])
         # decode
         _, decode_batch = self.generate_token([prefill_batch])
         # shifts
         self.shifting_warmup(decode_batch)
+
+        # if decode bs is 1 warmup ends here
+        if len(batches) == 0:
+            return
+
         # prefill
         _, prefill_batch = self.generate_token([batches.pop(0)])
         # concatenate and decode