Mirror of https://github.com/huggingface/text-generation-inference.git
Fix warmup shapes for corner cases (#136)

Co-authored-by: Karol Damaszke <kdamaszke@habana.ai>

Parent: 4169ff8e6f
Commit: bad7fe720a
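The diff below touches the router's warmup path (`impl Client`) and the Python `CausalLM` model. Taken together, the changes cap the warmup prefill batch size by the maximum decode batch size, push a second size-1 warmup batch per shape only when the decode budget leaves room for it (instead of unconditionally cloning the batch), and drain any pending decode batch when warmup ends early on the server side. A minimal sketch of the new capping logic, assuming a plain `bucket_default` value in place of the repo's `read_env_var` helper:

    use std::cmp;

    // Sketch (not the repo's code) of the capped prefill batch size from the first hunk.
    // `bucket_default` stands in for read_env_var("PREFILL_BATCH_BUCKET_SIZE", 8).
    fn capped_prefill_batch_size(
        max_prefill_tokens: u32,
        max_input_length: u32,
        max_batch_size: Option<usize>,
        bucket_default: u32,
    ) -> u32 {
        let max_prefill_batch_size = max_prefill_tokens / max_input_length;
        let max_decode_batch_size = match max_batch_size {
            Some(b) => b as u32,
            None => bucket_default,
        };
        cmp::min(max_prefill_batch_size, max_decode_batch_size)
    }

    fn main() {
        // 4096 prefill tokens / 1024 input length = 4, capped by a decode budget of 2.
        assert_eq!(capped_prefill_batch_size(4096, 1024, Some(2), 8), 2);
        assert_eq!(capped_prefill_batch_size(4096, 1024, None, 8), 4);
    }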
@@ -121,7 +121,12 @@ impl Client {
         };
 
         // get all possible prefill batch sizes
-        let max_prefill_batch_size: u32 = max_prefill_tokens / max_input_length;
+        let mut max_prefill_batch_size: u32 = max_prefill_tokens / max_input_length;
+        let max_decode_batch_size: u32 = match max_batch_size {
+            Some(max_batch_size) => max_batch_size as u32,
+            None => read_env_var("PREFILL_BATCH_BUCKET_SIZE", 8)
+        };
+        max_prefill_batch_size = cmp::min(max_prefill_batch_size, max_decode_batch_size);
         let prefill_bucket_size: u32 = read_env_var("PREFILL_BATCH_BUCKET_SIZE", 4);
         let batch_sizes: Vec<u32> = (prefill_bucket_size..max_prefill_batch_size+1).step_by(prefill_bucket_size as usize).collect();
 
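The `batch_sizes` range walks from the bucket size up to the (now capped) maximum in bucket-sized steps, so the new `cmp::min` directly limits how many prefill shapes get warmed up. A worked example of that same expression with concrete values:

    fn main() {
        // Buckets of 4 up to a capped maximum of 10 yield sizes 4 and 8.
        let prefill_bucket_size: u32 = 4;
        let max_prefill_batch_size: u32 = 10;
        let batch_sizes: Vec<u32> = (prefill_bucket_size..max_prefill_batch_size + 1)
            .step_by(prefill_bucket_size as usize)
            .collect();
        assert_eq!(batch_sizes, vec![4, 8]);
    }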
@@ -142,13 +147,10 @@ impl Client {
             }
         }
 
-        // if max_batch_size is None, create two batches
-        let num_batches = max_batch_size.unwrap_or(2).min(2);
        let mut id_counter: u64 = 0;
        for shape in shapes.iter() {
-            // create two batches in order to trigger concatenate operation
-            // in case decode bs=1 create one batch
-            let batches: Vec<Batch> = vec![
+            let (batch_size, seq_length) = shape;
+            let mut batches: Vec<Batch> = vec![
                 self.create_warmup_batch(
                     *shape,
                     &mut id_counter,
@@ -156,9 +158,22 @@ impl Client {
                     max_total_tokens,
                     seq_bucket_size,
                     false,
-                );
-                num_batches
+                )
             ];
+            // if possible, create second batch in order to trigger concatenate operation
+            if *batch_size < max_decode_batch_size {
+                batches.push(
+                    self.create_warmup_batch(
+                        (1, *seq_length),
+                        &mut id_counter,
+                        max_input_length,
+                        max_total_tokens,
+                        seq_bucket_size,
+                        false,
+                    )
+                );
+            }
+
             let request = tonic::Request::new(WarmupRequest {
                 batches,
                 max_input_length,
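This hunk handles the corner case named in the commit title: the old code built the batch list with the `vec![element; num_batches]` repeat form, always producing cloned batches, even when the decode budget had room for only one. Now one batch is sent per shape, and a size-1 batch is pushed only if the first batch sits below `max_decode_batch_size`. A standalone sketch of that decision (not the repo's API), reducing batches to their sizes:

    // Which warmup batch sizes get sent for a given prefill shape.
    fn warmup_batch_sizes(batch_size: u32, max_decode_batch_size: u32) -> Vec<u32> {
        let mut batches = vec![batch_size]; // first warmup batch for this shape
        if batch_size < max_decode_batch_size {
            // room left in the decode budget: add a size-1 batch so the
            // server also exercises the concatenate path during warmup
            batches.push(1);
        }
        batches
    }

    fn main() {
        assert_eq!(warmup_batch_sizes(4, 8), vec![4, 1]); // concat path warmed up
        assert_eq!(warmup_batch_sizes(8, 8), vec![8]);    // budget full: single batch
    }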
@@ -168,7 +183,7 @@ impl Client {
             let _response = self.stub.warmup(request).await?.into_inner();
         }
 
-        //Send batches with deafult params to warm up Greedy search
+        // send batches with default params to warm up Greedy search
         let mut greedy_shapes: Vec<(u32, u32)> = Vec::with_capacity(batch_sizes.len());
         for batch_size in &batch_sizes {
             greedy_shapes.push((*batch_size, seq_bucket_size.clone()));
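For reference, the greedy warmup pairs every capped batch size with the base sequence bucket. Assuming the `[4, 8]` batch sizes from the earlier example and a hypothetical sequence bucket of 128, the resulting greedy shapes would be:

    fn main() {
        // Each batch size paired with the sequence bucket size, as in the loop above.
        let batch_sizes: Vec<u32> = vec![4, 8];
        let seq_bucket_size: u32 = 128;
        let greedy_shapes: Vec<(u32, u32)> = batch_sizes
            .iter()
            .map(|&batch_size| (batch_size, seq_bucket_size))
            .collect();
        assert_eq!(greedy_shapes, vec![(4, 128), (8, 128)]);
    }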
@@ -182,8 +197,7 @@ impl Client {
                 max_total_tokens,
                 seq_bucket_size,
                 true,
-            );
-            num_batches
+            )
         ];
         let request = tonic::Request::new(WarmupRequest {
             batches,
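As in the prefill loop, the `num_batches` repeat count is gone from the greedy warmup's `vec![...]`, so each greedy `WarmupRequest` now carries a single batch built with default parameters.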
@@ -1114,6 +1114,8 @@ class CausalLM(Model):
 
         # if decode bs is 1 warmup ends here
         if len(batches) == 0:
+            while decode_batch is not None:
+                _, decode_batch, _ = self.generate_token([decode_batch])
             return
 
         # prefill
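On the server side, the early-return path (decode batch size 1, so there are no batches left to concatenate) previously returned with a decode batch still pending; warmup now keeps calling `generate_token` on that batch until it is exhausted, so the decode step is also exercised before returning.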