Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-24 00:12:08 +00:00)
Warmup greedy search in next token chooser (#109)

commit 2b1581edac (parent d752317b5f)
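In short, the warmup pass now sends a second round of batches, one shape per batch size, built with default (greedy) next-token-chooser parameters in addition to the existing sampling-parameter pass, and create_warmup_batch gains a default_params flag to pick between the two parameter sets. The following is a minimal sketch, not the router code itself: a standalone mirror of the greedy parameter set this commit adds, with field values copied from the diff below (the real NextTokenChooserParameters is the gRPC message from the TGI proto; the local struct here exists only to make the snippet self-contained).

// Minimal sketch: greedy warmup parameters as introduced by this commit.
// The struct only lists the fields visible in the diff below.
#[derive(Debug)]
struct NextTokenChooserParameters {
    temperature: f32,
    top_k: u32,
    top_p: f32,
    typical_p: f32,
    do_sample: bool,
    seed: u64,
    repetition_penalty: f32,
    watermark: bool,
}

fn main() {
    // Neutral settings: sampling disabled and all warpers/penalties at their
    // identity values, so the next token chooser follows its greedy path.
    let greedy = NextTokenChooserParameters {
        temperature: 1.0,
        top_k: 0,
        top_p: 1.0,
        typical_p: 1.0,
        do_sample: false,
        seed: 0,
        repetition_penalty: 1.0,
        watermark: false,
    };
    println!("greedy warmup params: {greedy:?}");
}

Sending warmup batches with these neutral values exercises the greedy decode path of the next token chooser ahead of time, which is presumably the point of the change, rather than leaving it to be hit first by live traffic.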
@@ -137,13 +137,26 @@ impl Client {
         for shape in shapes.iter() {
             // create two batches in order to trigger concatenate operation
             let batches: Vec<Batch> = vec![
-                self.create_warmup_batch(*shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size),
-                self.create_warmup_batch(*shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size)
+                self.create_warmup_batch(*shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size, false),
+                self.create_warmup_batch(*shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size, false)
             ];
             let request = tonic::Request::new(WarmupRequest { batches }).inject_context();
             let _response = self.stub.warmup(request).await?.into_inner();
         }

+        // Send batches with default params to warm up greedy search
+        let mut greedy_shapes: Vec<(u32, u32)> = Vec::with_capacity(batch_sizes.len());
+        for batch_size in &batch_sizes {
+            greedy_shapes.push((*batch_size, seq_bucket_size.clone()));
+        }
+        for greedy_shape in greedy_shapes.iter() {
+            let batches: Vec<Batch> = vec![
+                self.create_warmup_batch(*greedy_shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size, true),
+                self.create_warmup_batch(*greedy_shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size, true),
+            ];
+            let request = tonic::Request::new(WarmupRequest { batches }).inject_context();
+            let _response = self.stub.warmup(request).await?.into_inner();
+        }
+
         Ok(None) // No support for maximum total tokens
     }
@@ -155,16 +168,25 @@ impl Client {
         max_input_length: u32,
         max_total_tokens: u32,
         seq_bucket_size: u32,
+        default_params: bool,
     ) -> Batch {
         *id_counter += 1;
         let (batch_size, input_length) = shape;
         let mut requests = Vec::new();
         for request_id in 0..batch_size {
-            requests.push(Request {
-                id: *id_counter + request_id as u64,
-                inputs: self.get_random_input(input_length, seq_bucket_size),
-                truncate: max_input_length,
-                parameters: Some(NextTokenChooserParameters {
+            let req_params = if default_params {
+                Some(NextTokenChooserParameters {
+                    temperature: 1.0,
+                    top_k: 0,
+                    top_p: 1.0,
+                    typical_p: 1.0,
+                    do_sample: false,
+                    seed: 0,
+                    repetition_penalty: 1.0,
+                    watermark: false,
+                })
+            } else {
+                Some(NextTokenChooserParameters {
                     temperature: 0.9,
                     top_k: 10,
                     top_p: 0.9,
@@ -173,7 +195,13 @@ impl Client {
                     seed: 0,
                     repetition_penalty: 1.2,
                     watermark: true,
-                }),
+                })
+            };
+            requests.push(Request {
+                id: *id_counter + request_id as u64,
+                inputs: self.get_random_input(input_length, seq_bucket_size),
+                truncate: max_input_length,
+                parameters: req_params,
                 stopping_parameters: Some(StoppingCriteriaParameters {
                     max_new_tokens: cmp::min(10, max_total_tokens - max_input_length),
                     stop_sequences: vec![],