Warmup all decode buckets (#152)
Co-authored-by: Karol Damaszke <kdamaszke@habana.ai>
parent 7b879fd1d8
commit 0e8f8726db
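The router change below makes warmup cover every decode batch-size bucket rather than only the prefill shapes: decode batch sizes are the multiples of BATCH_BUCKET_SIZE (default 8 in the diff) up to max_decode_batch_size, and a running request count is rounded up to the nearest bucket. A minimal sketch of that bucket arithmetic, written in Python for brevity; the concrete values are illustrative assumptions, only the env-var name and its default come from the diff:

# Bucket arithmetic mirrored from the Rust hunk below (illustrative values).
decode_bucket_size = 8            # BATCH_BUCKET_SIZE (default 8 in the diff)
max_decode_batch_size = 32        # assumed value of max_batch_size for this example

# every decode batch-size bucket that warmup should hit
decode_batch_sizes = list(range(decode_bucket_size, max_decode_batch_size + 1, decode_bucket_size))
assert decode_batch_sizes == [8, 16, 24, 32]

def get_current_decode_batch_size(num: int) -> int:
    # round a request count up to the smallest bucket that can hold it
    return min(x for x in decode_batch_sizes if x >= num)

assert get_current_decode_batch_size(9) == 16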
@@ -120,15 +120,18 @@ impl Client {
             env::var(key).ok().map_or(default, |value| value.parse::<u32>().unwrap())
         };
 
-        // get all possible prefill batch sizes
-        let mut max_prefill_batch_size: u32 = max_prefill_tokens / max_input_length;
+        // get all possible batch sizes
+        let decode_bucket_size: u32 = read_env_var("BATCH_BUCKET_SIZE", 8);
         let max_decode_batch_size: u32 = match max_batch_size {
             Some(max_batch_size) => max_batch_size as u32,
-            None => read_env_var("PREFILL_BATCH_BUCKET_SIZE", 8)
+            None => decode_bucket_size
         };
-        max_prefill_batch_size = cmp::min(max_prefill_batch_size, max_decode_batch_size);
+        let decode_batch_sizes: Vec<u32> = (decode_bucket_size..max_decode_batch_size+1).step_by(decode_bucket_size as usize).collect();
+
         let prefill_bucket_size: u32 = read_env_var("PREFILL_BATCH_BUCKET_SIZE", 4);
-        let batch_sizes: Vec<u32> = (prefill_bucket_size..max_prefill_batch_size+1).step_by(prefill_bucket_size as usize).collect();
+        let mut max_prefill_batch_size: u32 = max_prefill_tokens / max_input_length;
+        max_prefill_batch_size = cmp::min(max_prefill_batch_size, max_decode_batch_size);
+        let prefill_batch_sizes: Vec<u32> = (prefill_bucket_size..max_prefill_batch_size+1).step_by(prefill_bucket_size as usize).collect();
 
         // get all possible sequence lengths for prefill
         let seq_bucket_size: u32 = read_env_var("PAD_SEQUENCE_TO_MULTIPLE_OF", 128);
@@ -140,8 +143,8 @@ impl Client {
         }
 
         // execute batch for each combination of batch size and sequence length
-        let mut shapes: Vec<(u32, u32)> = Vec::with_capacity(batch_sizes.len() * seq_lengths.len());
-        for batch_size in &batch_sizes {
+        let mut shapes: Vec<(u32, u32)> = Vec::with_capacity(prefill_batch_sizes.len() * seq_lengths.len());
+        for batch_size in &prefill_batch_sizes {
             for seq_length in &seq_lengths {
                 shapes.push((*batch_size, *seq_length));
             }
@@ -183,9 +186,59 @@ impl Client {
             let _response = self.stub.warmup(request).await?.into_inner();
         }
 
+        // send batches to warmup all possible decode shapes
+        if decode_batch_sizes.len() > 1 {
+            let mut requests_send: u32 = cmp::min(max_prefill_batch_size, decode_bucket_size);
+            let mut batches: Vec<Batch> = vec![
+                self.create_warmup_batch(
+                    (requests_send, seq_bucket_size),
+                    &mut id_counter,
+                    max_input_length,
+                    max_total_tokens,
+                    seq_bucket_size,
+                    false,
+                )
+            ];
+
+            let get_current_decode_batch_size = |num: u32| -> u32 {
+                decode_batch_sizes.iter()
+                    .filter(|&&x| x >= num)
+                    .min()
+                    .copied()
+                    .unwrap()
+            };
+
+            let mut current_decode_batch_size: u32 = get_current_decode_batch_size(requests_send);
+            while current_decode_batch_size < max_decode_batch_size {
+                let distance_to_next_bucket = current_decode_batch_size + decode_bucket_size - requests_send;
+                let num_requests: u32 = cmp::min(distance_to_next_bucket, max_prefill_batch_size);
+                batches.push(
+                    self.create_warmup_batch(
+                        (num_requests, seq_bucket_size),
+                        &mut id_counter,
+                        max_input_length,
+                        max_total_tokens,
+                        seq_bucket_size,
+                        false,
+                    )
+                );
+
+                requests_send += num_requests;
+                current_decode_batch_size = get_current_decode_batch_size(requests_send);
+            }
+
+            let request = tonic::Request::new(WarmupRequest {
+                batches,
+                max_input_length,
+                max_prefill_tokens,
+                max_total_tokens,
+            }).inject_context();
+            let _response = self.stub.warmup(request).await?.into_inner();
+        }
+
         // send batches with default params to warm up Greedy search
-        let mut greedy_shapes: Vec<(u32, u32)> = Vec::with_capacity(batch_sizes.len());
-        for batch_size in &batch_sizes {
+        let mut greedy_shapes: Vec<(u32, u32)> = Vec::with_capacity(prefill_batch_sizes.len());
+        for batch_size in &prefill_batch_sizes {
             greedy_shapes.push((*batch_size, seq_bucket_size.clone()));
         }
         for greedy_shape in greedy_shapes.iter() {
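Reading the new decode-warmup block above: the client first builds a prefill batch of min(max_prefill_batch_size, decode_bucket_size) requests, then keeps appending prefill batches just large enough (capped at max_prefill_batch_size) to lift the accumulated request count to the next decode bucket, until max_decode_batch_size is reached, so the server decodes at every bucketed batch size. A small Python simulation of that accumulation; the values (bucket size 8, max decode batch 32, max prefill batch 4) are assumptions for illustration only:

# Simulate the request counts of the warmup prefill batches (illustrative values).
decode_bucket_size, max_decode_batch_size, max_prefill_batch_size = 8, 32, 4
decode_batch_sizes = list(range(decode_bucket_size, max_decode_batch_size + 1, decode_bucket_size))

def get_current_decode_batch_size(num):
    return min(x for x in decode_batch_sizes if x >= num)

requests_send = min(max_prefill_batch_size, decode_bucket_size)
batch_request_counts = [requests_send]
current = get_current_decode_batch_size(requests_send)
while current < max_decode_batch_size:
    distance_to_next_bucket = current + decode_bucket_size - requests_send
    num_requests = min(distance_to_next_bucket, max_prefill_batch_size)
    batch_request_counts.append(num_requests)
    requests_send += num_requests
    current = get_current_decode_batch_size(requests_send)

# eight batches of 4 requests; running totals 4, 8, 12, ..., 32 pad up to the 8, 16, 24 and 32 buckets
print(batch_request_counts)  # [4, 4, 4, 4, 4, 4, 4, 4]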
@@ -1124,6 +1124,12 @@ class CausalLM(Model):
         return generations, batch if not stopped else None, (forward_ns, decode_ns)
 
     def warmup(self, batches: List[CausalLMBatch]) -> None:
+        def get_unfinished_requests(requests: List[CausalLMRequest]) -> List[int]:
+            return [
+                request.data.id for request in requests
+                if request.stopping_criteria.current_tokens < request.stopping_criteria.max_new_tokens
+            ]
+
         # prefill
         _, prefill_batch, _ = self.generate_token([batches.pop(0)])
         # decode
@@ -1131,18 +1137,22 @@ class CausalLM(Model):
         # shifts
         self.shifting_warmup(decode_batch)
 
-        # if decode bs is 1 warmup ends here
-        if len(batches) == 0:
-            while decode_batch is not None:
-                _, decode_batch, _ = self.generate_token([decode_batch])
-            return
-
-        # prefill
-        _, prefill_batch, _ = self.generate_token([batches.pop(0)])
-        # concatenate and decode
-        _, decode_batch, _ = self.generate_token([decode_batch, prefill_batch])
-        # decodes
+        while len(batches) > 0:
+            # prefill
+            _, prefill_batch, _ = self.generate_token([batches.pop(0)])
+            # concatenate and decode
+            _, decode_batch, _ = self.generate_token([decode_batch, prefill_batch])
+            # filter finished requests
+            request_ids = get_unfinished_requests(decode_batch.requests)
+            if len(request_ids) < len(decode_batch.requests):
+                decode_batch = decode_batch.filter(request_ids)
+
         while decode_batch is not None:
+            # filter finished requests
+            request_ids = get_unfinished_requests(decode_batch.requests)
+            if len(request_ids) < len(decode_batch.requests):
+                decode_batch = decode_batch.filter(request_ids)
+            # decode
            _, decode_batch, _ = self.generate_token([decode_batch])
 
     def shifting_warmup(self, batch: CausalLMBatch) -> None:
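On the Python side, warmup now concatenates every remaining prefill batch into the decode batch and drops finished requests between decode steps, so the decode batch shrinks through the smaller bucketed shapes during warmup. A self-contained sketch of the filtering predicate; the namedtuple stand-ins below only mimic the fields the diff accesses and are not the server's actual classes:

from collections import namedtuple

# Stand-ins mimicking the fields used by get_unfinished_requests (illustrative only).
Data = namedtuple("Data", "id")
StoppingCriteria = namedtuple("StoppingCriteria", "current_tokens max_new_tokens")
Request = namedtuple("Request", "data stopping_criteria")

def get_unfinished_requests(requests):
    # keep the ids of requests that still have new tokens left to generate
    return [
        request.data.id for request in requests
        if request.stopping_criteria.current_tokens < request.stopping_criteria.max_new_tokens
    ]

requests = [
    Request(Data(0), StoppingCriteria(current_tokens=10, max_new_tokens=10)),  # finished
    Request(Data(1), StoppingCriteria(current_tokens=3, max_new_tokens=10)),   # still generating
]
assert get_unfinished_requests(requests) == [1]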