Mirror of https://github.com/huggingface/text-generation-inference.git
Warmup all decode buckets (#152)
Co-authored-by: Karol Damaszke <kdamaszke@habana.ai>
parent 7b879fd1d8
commit 0e8f8726db
@@ -120,15 +120,18 @@ impl Client {
             env::var(key).ok().map_or(default, |value| value.parse::<u32>().unwrap())
         };

-        // get all possible prefill batch sizes
-        let mut max_prefill_batch_size: u32 = max_prefill_tokens / max_input_length;
+        // get all possible batch sizes
         let decode_bucket_size: u32 = read_env_var("BATCH_BUCKET_SIZE", 8);
         let max_decode_batch_size: u32 = match max_batch_size {
             Some(max_batch_size) => max_batch_size as u32,
-            None => read_env_var("PREFILL_BATCH_BUCKET_SIZE", 8)
+            None => decode_bucket_size
         };
-        max_prefill_batch_size = cmp::min(max_prefill_batch_size, max_decode_batch_size);
+        let decode_batch_sizes: Vec<u32> = (decode_bucket_size..max_decode_batch_size+1).step_by(decode_bucket_size as usize).collect();
+
         let prefill_bucket_size: u32 = read_env_var("PREFILL_BATCH_BUCKET_SIZE", 4);
-        let batch_sizes: Vec<u32> = (prefill_bucket_size..max_prefill_batch_size+1).step_by(prefill_bucket_size as usize).collect();
+        let mut max_prefill_batch_size: u32 = max_prefill_tokens / max_input_length;
+        max_prefill_batch_size = cmp::min(max_prefill_batch_size, max_decode_batch_size);
+        let prefill_batch_sizes: Vec<u32> = (prefill_bucket_size..max_prefill_batch_size+1).step_by(prefill_bucket_size as usize).collect();
+
         // get all possible sequence lengths for prefill
         let seq_bucket_size: u32 = read_env_var("PAD_SEQUENCE_TO_MULTIPLE_OF", 128);
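The hunk above derives two bucket lists: decode batch sizes stepped by BATCH_BUCKET_SIZE up to max_decode_batch_size, and prefill batch sizes stepped by PREFILL_BATCH_BUCKET_SIZE up to a prefill limit capped by the decode limit. A minimal sketch of that arithmetic, with assumed values rather than real defaults:

    # Rough sketch of the bucket arithmetic above; the numbers are assumed, not real configuration.
    def buckets(bucket_size: int, max_size: int) -> list[int]:
        # mirrors (bucket_size..max_size+1).step_by(bucket_size): inclusive of max_size
        return list(range(bucket_size, max_size + 1, bucket_size))

    decode_bucket_size = 8                       # BATCH_BUCKET_SIZE
    max_decode_batch_size = 32                   # e.g. taken from max_batch_size
    prefill_bucket_size = 4                      # PREFILL_BATCH_BUCKET_SIZE
    max_prefill_batch_size = min(16384 // 1024,  # max_prefill_tokens / max_input_length
                                 max_decode_batch_size)

    print(buckets(decode_bucket_size, max_decode_batch_size))    # [8, 16, 24, 32]
    print(buckets(prefill_bucket_size, max_prefill_batch_size))  # [4, 8, 12, 16]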
@@ -140,8 +143,8 @@ impl Client {
         }

         // execute batch for each combination of batch size and sequence length
-        let mut shapes: Vec<(u32, u32)> = Vec::with_capacity(batch_sizes.len() * seq_lengths.len());
-        for batch_size in &batch_sizes {
+        let mut shapes: Vec<(u32, u32)> = Vec::with_capacity(prefill_batch_sizes.len() * seq_lengths.len());
+        for batch_size in &prefill_batch_sizes {
             for seq_length in &seq_lengths {
                 shapes.push((*batch_size, *seq_length));
             }
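The prefill warmup enumerates one shape per (batch size, sequence length) pair. A tiny sketch of the same cross product, with assumed values (the sequence lengths stand in for the PAD_SEQUENCE_TO_MULTIPLE_OF buckets mentioned above):

    # Illustrative only: cross product of prefill batch sizes and sequence lengths.
    prefill_batch_sizes = [4, 8, 12, 16]
    seq_lengths = [128, 256, 384, 512]   # assumed sequence-length buckets
    shapes = [(bs, sl) for bs in prefill_batch_sizes for sl in seq_lengths]
    print(len(shapes))   # 16 warmup shapes
    print(shapes[:3])    # [(4, 128), (4, 256), (4, 384)]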
@@ -183,9 +186,59 @@ impl Client {
             let _response = self.stub.warmup(request).await?.into_inner();
         }

+        // send batches to warmup all possible decode shapes
+        if decode_batch_sizes.len() > 1 {
+            let mut requests_send: u32 = cmp::min(max_prefill_batch_size, decode_bucket_size);
+            let mut batches: Vec<Batch> = vec![
+                self.create_warmup_batch(
+                    (requests_send, seq_bucket_size),
+                    &mut id_counter,
+                    max_input_length,
+                    max_total_tokens,
+                    seq_bucket_size,
+                    false,
+                )
+            ];
+
+            let get_current_decode_batch_size = |num: u32| -> u32 {
+                decode_batch_sizes.iter()
+                    .filter(|&&x| x >= num)
+                    .min()
+                    .copied()
+                    .unwrap()
+            };
+
+            let mut current_decode_batch_size: u32 = get_current_decode_batch_size(requests_send);
+            while current_decode_batch_size < max_decode_batch_size {
+                let distance_to_next_bucket = current_decode_batch_size + decode_bucket_size - requests_send;
+                let num_requests: u32 = cmp::min(distance_to_next_bucket, max_prefill_batch_size);
+                batches.push(
+                    self.create_warmup_batch(
+                        (num_requests, seq_bucket_size),
+                        &mut id_counter,
+                        max_input_length,
+                        max_total_tokens,
+                        seq_bucket_size,
+                        false,
+                    )
+                );
+
+                requests_send += num_requests;
+                current_decode_batch_size = get_current_decode_batch_size(requests_send);
+            }
+
+            let request = tonic::Request::new(WarmupRequest {
+                batches,
+                max_input_length,
+                max_prefill_tokens,
+                max_total_tokens,
+            }).inject_context();
+            let _response = self.stub.warmup(request).await?.into_inner();
+        }
+
         // send batches with default params to warm up Greedy search
-        let mut greedy_shapes: Vec<(u32, u32)> = Vec::with_capacity(batch_sizes.len());
-        for batch_size in &batch_sizes {
+        let mut greedy_shapes: Vec<(u32, u32)> = Vec::with_capacity(prefill_batch_sizes.len());
+        for batch_size in &prefill_batch_sizes {
             greedy_shapes.push((*batch_size, seq_bucket_size.clone()));
         }
         for greedy_shape in greedy_shapes.iter() {
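The new block above is the core of the change: after the prefill shapes are warmed up, the router keeps appending warmup batches until the accumulated request count has passed through every decode bucket up to max_decode_batch_size, as the commit title says. A small simulation of that scheduling loop, with assumed bucket settings (the names mirror the closure and variables above, the values are made up):

    # Simulation of the decode-bucket scheduling loop above (assumed values).
    decode_bucket_size = 8
    max_decode_batch_size = 32
    decode_batch_sizes = list(range(decode_bucket_size, max_decode_batch_size + 1, decode_bucket_size))
    max_prefill_batch_size = 4

    def get_current_decode_batch_size(num: int) -> int:
        # smallest decode bucket that can hold `num` concurrent requests
        return min(x for x in decode_batch_sizes if x >= num)

    requests_send = min(max_prefill_batch_size, decode_bucket_size)
    batch_sizes_sent = [requests_send]
    current = get_current_decode_batch_size(requests_send)
    while current < max_decode_batch_size:
        distance_to_next_bucket = current + decode_bucket_size - requests_send
        num_requests = min(distance_to_next_bucket, max_prefill_batch_size)
        batch_sizes_sent.append(num_requests)
        requests_send += num_requests
        current = get_current_decode_batch_size(requests_send)

    print(batch_sizes_sent)   # [4, 4, 4, 4, 4, 4, 4]; cumulative counts 4..28 visit buckets 8, 16, 24, 32
    print(requests_send)      # 28, which rounds up into the last bucket (32)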
@@ -1124,6 +1124,12 @@ class CausalLM(Model):
         return generations, batch if not stopped else None, (forward_ns, decode_ns)

     def warmup(self, batches: List[CausalLMBatch]) -> None:
+        def get_unfinished_requests(requests: List[CausalLMRequest]) -> List[int]:
+            return [
+                request.data.id for request in requests
+                if request.stopping_criteria.current_tokens < request.stopping_criteria.max_new_tokens
+            ]
+
         # prefill
         _, prefill_batch, _ = self.generate_token([batches.pop(0)])
         # decode
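On the model-server side, the new get_unfinished_requests helper returns the ids of requests that have not yet produced max_new_tokens, so the warmup loop in the next hunk can drop finished requests from the decode batch. A standalone sketch of the same predicate with stub objects (the attribute layout mirrors the fields used above; the classes here are made up for illustration):

    # Stub objects imitating the fields read by get_unfinished_requests above.
    from dataclasses import dataclass

    @dataclass
    class StoppingCriteria:
        current_tokens: int
        max_new_tokens: int

    @dataclass
    class Data:
        id: int

    @dataclass
    class Request:
        data: Data
        stopping_criteria: StoppingCriteria

    requests = [
        Request(Data(0), StoppingCriteria(current_tokens=10, max_new_tokens=10)),  # finished
        Request(Data(1), StoppingCriteria(current_tokens=3, max_new_tokens=10)),   # still running
    ]
    unfinished = [
        r.data.id for r in requests
        if r.stopping_criteria.current_tokens < r.stopping_criteria.max_new_tokens
    ]
    print(unfinished)  # [1]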
@@ -1131,18 +1137,22 @@ class CausalLM(Model):
         # shifts
         self.shifting_warmup(decode_batch)

-        # if decode bs is 1 warmup ends here
-        if len(batches) == 0:
-            while decode_batch is not None:
-                _, decode_batch, _ = self.generate_token([decode_batch])
-            return
-
-        # prefill
-        _, prefill_batch, _ = self.generate_token([batches.pop(0)])
-        # concatenate and decode
-        _, decode_batch, _ = self.generate_token([decode_batch, prefill_batch])
+        while len(batches) > 0:
+            # prefill
+            _, prefill_batch, _ = self.generate_token([batches.pop(0)])
+            # concatenate and decode
+            _, decode_batch, _ = self.generate_token([decode_batch, prefill_batch])
+            # filter finished requests
+            request_ids = get_unfinished_requests(decode_batch.requests)
+            if len(request_ids) < len(decode_batch.requests):
+                decode_batch = decode_batch.filter(request_ids)
         # decodes
         while decode_batch is not None:
+            # filter finished requests
+            request_ids = get_unfinished_requests(decode_batch.requests)
+            if len(request_ids) < len(decode_batch.requests):
+                decode_batch = decode_batch.filter(request_ids)
+            # decode
             _, decode_batch, _ = self.generate_token([decode_batch])

     def shifting_warmup(self, batch: CausalLMBatch) -> None:
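The reworked warmup now consumes every extra batch the router sends, concatenates it into the running decode batch, drops finished requests, and finally drains the remaining decodes. A toy trace of that control flow, using plain lists of "remaining new tokens" in place of CausalLMBatch (illustrative only, not the real batch API):

    # Toy trace of the reworked warmup flow above.
    batches = [[3, 3, 3, 3], [5, 5, 5, 5], [7, 7, 7, 7]]   # extra warmup batches from the router
    decode_batch = [2, 2, 2, 2]                             # batch left after the first prefill + decode

    def generate_token(batch):
        # one decode step: every request produces one token
        return [t - 1 for t in batch]

    while batches:
        decode_batch = generate_token(decode_batch + batches.pop(0))   # concatenate and decode
        decode_batch = [t for t in decode_batch if t > 0]              # filter finished requests
        print(len(decode_batch))
    while decode_batch:
        decode_batch = [t for t in generate_token(decode_batch) if t > 0]   # decode until done
        print(len(decode_batch))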