Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-22 23:42:06 +00:00
Adjust warmup to all possible bucket sizes and decode batch size = 1 (#113)
Parent: 9796b0e10d
Commit: 56f00a552b
@@ -106,6 +106,7 @@ impl Client {
         max_input_length: u32,
         max_prefill_tokens: u32,
         max_total_tokens: u32,
+        max_batch_total_tokens: Option<u32>,
     ) -> Result<Option<u32>> {
         let warmup_enabled: bool = env::var("WARMUP_ENABLED").ok().map_or(true, |value| value.to_lowercase() == "true");
         if !warmup_enabled {
@@ -123,7 +124,12 @@ impl Client {
 
         // get all possible sequence lengths for prefill
         let seq_bucket_size: u32 = read_env_var("PAD_SEQUENCE_TO_MULTIPLE_OF", 128);
-        let seq_lengths: Vec<u32> = (seq_bucket_size..max_input_length+1).step_by(seq_bucket_size as usize).collect();
+        let mut seq_lengths: Vec<u32> = (seq_bucket_size..max_input_length+1).step_by(seq_bucket_size as usize).collect();
+        if let Some(&last) = seq_lengths.last() {
+            if last < max_input_length {
+                seq_lengths.push(max_input_length);
+            }
+        }
 
         // execute batch for each combination of batch size and sequence length
         let mut shapes: Vec<(u32, u32)> = Vec::with_capacity(batch_sizes.len() * seq_lengths.len());
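Note on the hunk above: warmup now enumerates every prefill bucket and additionally appends max_input_length itself when it is not an exact multiple of the bucket size. A minimal standalone sketch of that enumeration (the helper name and the values 128 and 1000 are illustrative, not taken from this diff):

```rust
// Standalone sketch of the bucket enumeration; names and values are illustrative only.
fn prefill_seq_lengths(seq_bucket_size: u32, max_input_length: u32) -> Vec<u32> {
    // Same stepping as the warmup code: bucket, 2*bucket, ..., up to max_input_length.
    let mut seq_lengths: Vec<u32> = (seq_bucket_size..max_input_length + 1)
        .step_by(seq_bucket_size as usize)
        .collect();
    // New in this commit: always include the largest supported input length,
    // even when it is not an exact multiple of the bucket size.
    if let Some(&last) = seq_lengths.last() {
        if last < max_input_length {
            seq_lengths.push(max_input_length);
        }
    }
    seq_lengths
}

fn main() {
    // With a 128-token bucket and max_input_length = 1000 (illustrative values),
    // this prints [128, 256, 384, 512, 640, 768, 896, 1000].
    println!("{:?}", prefill_seq_lengths(128, 1000));
}
```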
@@ -134,11 +140,29 @@ impl Client {
         }
 
         let mut id_counter: u64 = 0;
+        let num_batches = match max_batch_total_tokens {
+            Some(val) => {
+                if val == max_total_tokens {
+                    1
+                } else {
+                    2
+                }
+            }
+            None => 2, // If max_batch_total_tokens is None, create two batches
+        };
         for shape in shapes.iter() {
             // create two batches in order to trigger concatenate operation
+            // in case decode bs=1 create one batch
             let batches: Vec<Batch> = vec![
-                self.create_warmup_batch(*shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size, false),
-                self.create_warmup_batch(*shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size, false)
+                self.create_warmup_batch(
+                    *shape,
+                    &mut id_counter,
+                    max_input_length,
+                    max_total_tokens,
+                    seq_bucket_size,
+                    false,
+                );
+                num_batches
             ];
             let request = tonic::Request::new(WarmupRequest { batches }).inject_context();
             let _response = self.stub.warmup(request).await?.into_inner();
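Usage note on the `vec![element; count]` form above: the warmup batch is constructed once and cloned `num_batches` times, so either a single batch (decode batch size = 1) or a pair of batches is sent, depending on whether `max_batch_total_tokens` equals `max_total_tokens`. A minimal sketch under assumed values (the helper `num_warmup_batches` and the token budgets below are illustrative, not from this diff):

```rust
// Sketch of the batch-count selection and the `vec![element; count]` repeat form.
// The max_total_tokens / max_batch_total_tokens values are illustrative only.
fn num_warmup_batches(max_batch_total_tokens: Option<u32>, max_total_tokens: u32) -> usize {
    match max_batch_total_tokens {
        // Decode batch size 1: the whole batch budget fits a single sequence, so a
        // second batch (and therefore the concatenate step) can never occur at runtime.
        Some(val) if val == max_total_tokens => 1,
        // Otherwise keep warming up the concatenate path with two batches.
        _ => 2,
    }
}

fn main() {
    // `vec![elem; n]` clones the element n times, which is how the warmup loop
    // turns one constructed batch into either one or two identical requests.
    let batches: Vec<String> = vec!["warmup-batch".to_string(); num_warmup_batches(Some(2048), 2048)];
    assert_eq!(batches.len(), 1);

    let batches: Vec<String> = vec!["warmup-batch".to_string(); num_warmup_batches(None, 2048)];
    assert_eq!(batches.len(), 2);
}
```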
@@ -151,8 +175,15 @@ impl Client {
         }
         for greedy_shape in greedy_shapes.iter() {
             let batches: Vec<Batch> = vec![
-                self.create_warmup_batch(*greedy_shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size, true),
-                self.create_warmup_batch(*greedy_shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size, true),
+                self.create_warmup_batch(
+                    *greedy_shape,
+                    &mut id_counter,
+                    max_input_length,
+                    max_total_tokens,
+                    seq_bucket_size,
+                    true,
+                );
+                num_batches
             ];
             let request = tonic::Request::new(WarmupRequest { batches }).inject_context();
             let _response = self.stub.warmup(request).await?.into_inner();
@@ -233,14 +264,21 @@ impl Client {
             // generate random tokens
             let mut rng = rand::thread_rng();
             let range = Uniform::new(2, 8192);
-            let tokens = input_length - seq_bucket_size / 2;
+            let tokens = if input_length % seq_bucket_size == 0 {
+                input_length - seq_bucket_size / 2
+            } else {
+                input_length - (input_length % seq_bucket_size) / 2
+            };
             (0..tokens)
                 .map(|_| rng.sample(&range).to_string())
                 .collect::<Vec<String>>()
                 .join(", ")
         } else {
             // repeat test string to get expected input shape
-            let bucket_id = input_length / seq_bucket_size;
+            let mut bucket_id = input_length / seq_bucket_size;
+            if input_length % seq_bucket_size != 0 {
+                bucket_id += 1
+            }
             let repeats = cmp::max(1, (bucket_id - 1) * seq_bucket_size / 2);
             "_test ".to_string().repeat(repeats as usize)
         }
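To make the arithmetic above concrete, a small sketch with assumed values (a 128-token bucket and input lengths 512 and 1000 are illustrative; `warmup_token_count` is a hypothetical helper name): an exact multiple backs off by half a bucket, a partial last bucket backs off by half of its own remainder, and the bucket index now rounds up instead of truncating.

```rust
// Worked example of the new token-count and bucket-index arithmetic.
// seq_bucket_size and the input lengths are illustrative values, not from the diff.
fn warmup_token_count(input_length: u32, seq_bucket_size: u32) -> u32 {
    if input_length % seq_bucket_size == 0 {
        input_length - seq_bucket_size / 2
    } else {
        input_length - (input_length % seq_bucket_size) / 2
    }
}

fn bucket_id(input_length: u32, seq_bucket_size: u32) -> u32 {
    // Round up instead of truncating, so a partial last bucket counts as its own bucket.
    let mut id = input_length / seq_bucket_size;
    if input_length % seq_bucket_size != 0 {
        id += 1;
    }
    id
}

fn main() {
    let bucket = 128;
    // Exact multiple: 512 tokens target 512 - 64 = 448 random tokens, bucket 4.
    assert_eq!(warmup_token_count(512, bucket), 448);
    assert_eq!(bucket_id(512, bucket), 4);
    // Partial bucket: 1000 tokens target 1000 - (1000 % 128) / 2 = 948, bucket 8.
    assert_eq!(warmup_token_count(1000, bucket), 948);
    assert_eq!(bucket_id(1000, bucket), 8);
}
```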
@@ -96,12 +96,13 @@ impl ShardedClient {
         max_input_length: u32,
         max_prefill_tokens: u32,
         max_total_tokens: u32,
+        max_batch_total_tokens: Option<u32>,
     ) -> Result<Option<u32>> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
             .map(|client| {
-                Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens))
+                Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens, max_batch_total_tokens))
             })
             .collect();
         // Take the minimum value
@@ -220,7 +220,7 @@ fn main() -> Result<(), RouterError> {
     // Warmup model
     tracing::info!("Warming up model");
     let max_supported_batch_total_tokens = match sharded_client
-        .warmup(max_input_length as u32, max_batch_prefill_tokens, max_total_tokens as u32)
+        .warmup(max_input_length as u32, max_batch_prefill_tokens, max_total_tokens as u32, max_batch_total_tokens)
         .await
         .map_err(RouterError::Warmup)?
     {
@@ -461,7 +461,11 @@ class CausalLMBatch(Batch):
         left_padding = max_input_length - input_len
         if input_len < max_input_length and PAD_SEQUENCE_TO_MULTIPLE_OF != 0:
             assert PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length, "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
-            bucket_size = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF) - 1
+            rounded_seq_len = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
+            if rounded_seq_len <= max_input_length:
+                bucket_size = rounded_seq_len - 1
+            else:
+                bucket_size = max_input_length - 1
             left_padding = bucket_size - input_len
 
         input_ids = tokenized_inputs["input_ids"]
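The server-side change above is Python; for illustration only, here is a Rust restatement of the new clamping behaviour, assuming `round_up` rounds its first argument up to the nearest multiple of the second (the values 128 and 1000 are illustrative, not from this diff):

```rust
// Rust restatement of the Python padding change above, for illustration only;
// the server code itself is Python. Values below are assumed, not from the diff.
fn round_up(value: u32, multiple: u32) -> u32 {
    ((value + multiple - 1) / multiple) * multiple
}

fn bucket_size(input_len: u32, pad_multiple: u32, max_input_length: u32) -> u32 {
    let rounded_seq_len = round_up(input_len + 1, pad_multiple);
    // New in this commit: never pad past max_input_length, even when rounding up would.
    if rounded_seq_len <= max_input_length {
        rounded_seq_len - 1
    } else {
        max_input_length - 1
    }
}

fn main() {
    // With PAD_SEQUENCE_TO_MULTIPLE_OF = 128 and max_input_length = 1000 (illustrative):
    assert_eq!(bucket_size(200, 128, 1000), 255); // rounds up to 256, bucket is 255
    assert_eq!(bucket_size(950, 128, 1000), 999); // 1024 > 1000, so clamp to 999
}
```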
@@ -1049,15 +1053,17 @@ class CausalLM(Model):
         return generations, batch if not stopped else None
 
     def warmup(self, batches: List[CausalLMBatch]) -> None:
-        if len(batches) < 2:
-            return
         # prefill
         _, prefill_batch = self.generate_token([batches.pop(0)])
         # decode
         _, decode_batch = self.generate_token([prefill_batch])
         # shifts
         self.shifting_warmup(decode_batch)
 
+        # if decode bs is 1 warmup ends here
+        if len(batches) == 0:
+            return
+
         # prefill
         _, prefill_batch = self.generate_token([batches.pop(0)])
         # concatenate and decode