Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-24 00:12:08 +00:00)
Warmup greedy search in next token chooser (#109)
This commit is contained in:
parent d752317b5f
commit 2b1581edac
@@ -137,13 +137,26 @@ impl Client {
         for shape in shapes.iter() {
             // create two batches in order to trigger concatenate operation
             let batches: Vec<Batch> = vec![
-                self.create_warmup_batch(*shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size),
-                self.create_warmup_batch(*shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size)
+                self.create_warmup_batch(*shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size, false),
+                self.create_warmup_batch(*shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size, false)
             ];
             let request = tonic::Request::new(WarmupRequest { batches }).inject_context();
             let _response = self.stub.warmup(request).await?.into_inner();
         }
 
+        // Send batches with default params to warm up greedy search
+        let mut greedy_shapes: Vec<(u32, u32)> = Vec::with_capacity(batch_sizes.len());
+        for batch_size in &batch_sizes {
+            greedy_shapes.push((*batch_size, seq_bucket_size.clone()));
+        }
+        for greedy_shape in greedy_shapes.iter() {
+            let batches: Vec<Batch> = vec![
+                self.create_warmup_batch(*greedy_shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size, true),
+                self.create_warmup_batch(*greedy_shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size, true),
+            ];
+            let request = tonic::Request::new(WarmupRequest { batches }).inject_context();
+            let _response = self.stub.warmup(request).await?.into_inner();
+        }
         Ok(None) // No support for maximum total tokens
     }
 
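The greedy warmup pass above pairs each configured batch size with a single input length of seq_bucket_size, so every batch size is exercised once under greedy decoding. As a minimal sketch (not part of the patch, assuming batch_sizes: Vec<u32> and seq_bucket_size: u32 as in the surrounding code), the shape-building loop is equivalent to an iterator chain; note that .clone() on a u32 is just a copy and could be dropped:

    let greedy_shapes: Vec<(u32, u32)> = batch_sizes
        .iter()
        .map(|&batch_size| (batch_size, seq_bucket_size))
        .collect();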
@@ -155,16 +168,25 @@ impl Client {
         max_input_length: u32,
         max_total_tokens: u32,
         seq_bucket_size: u32,
+        default_params: bool,
     ) -> Batch {
         *id_counter += 1;
         let (batch_size, input_length) = shape;
         let mut requests = Vec::new();
         for request_id in 0..batch_size {
-            requests.push(Request {
-                id: *id_counter + request_id as u64,
-                inputs: self.get_random_input(input_length, seq_bucket_size),
-                truncate: max_input_length,
-                parameters: Some(NextTokenChooserParameters {
+            let req_params = if default_params {
+                Some(NextTokenChooserParameters {
+                    temperature: 1.0,
+                    top_k: 0,
+                    top_p: 1.0,
+                    typical_p: 1.0,
+                    do_sample: false,
+                    seed: 0,
+                    repetition_penalty: 1.0,
+                    watermark: false,
+                })
+            } else {
+                Some(NextTokenChooserParameters {
                     temperature: 0.9,
                     top_k: 10,
                     top_p: 0.9,
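The new default_params flag selects a neutral parameter set (temperature 1.0, top_k 0, top_p 1.0, typical_p 1.0, do_sample false, repetition_penalty 1.0, watermark off) under which the next-token chooser reduces to picking the highest-scoring token, while false keeps the pre-existing sampling set (temperature 0.9, top_k 10, top_p 0.9, ...). As an illustrative sketch only (this helper does not exist in the repository), greedy selection amounts to an argmax over the logits:

    // Hypothetical helper: with the neutral parameters above, next-token
    // selection is an argmax over the raw logits. total_cmp gives a total
    // order on f32, so NaNs cannot panic the comparison.
    fn greedy_next_token(logits: &[f32]) -> Option<usize> {
        logits
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.total_cmp(b))
            .map(|(i, _)| i)
    }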
@@ -173,7 +195,13 @@ impl Client {
                     seed: 0,
                     repetition_penalty: 1.2,
                     watermark: true,
-                }),
+                })
+            };
+            requests.push(Request {
+                id: *id_counter + request_id as u64,
+                inputs: self.get_random_input(input_length, seq_bucket_size),
+                truncate: max_input_length,
+                parameters: req_params,
                 stopping_parameters: Some(StoppingCriteriaParameters {
                     max_new_tokens: cmp::min(10, max_total_tokens - max_input_length),
                     stop_sequences: vec![],
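Factoring the parameter choice into req_params keeps a single Request construction shared by both modes; callers opt into the greedy warmup via the trailing boolean. A sketch of a call site, mirroring the warmup loops in the first hunk (all names taken from the surrounding code):

    let batch = self.create_warmup_batch(
        (batch_size, seq_bucket_size), // shape: (batch_size, input_length)
        &mut id_counter,
        max_input_length,
        max_total_tokens,
        seq_bucket_size,
        true, // default_params: warm up the greedy path
    );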