Warmup greedy search in next token chooser (#109)

This commit is contained in:
jkaniecki 2024-03-22 23:43:20 +01:00 committed by GitHub
parent d752317b5f
commit 2b1581edac
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -137,13 +137,26 @@ impl Client {
for shape in shapes.iter() { for shape in shapes.iter() {
// create two batches in order to trigger concatenate operation // create two batches in order to trigger concatenate operation
let batches: Vec<Batch> = vec![ let batches: Vec<Batch> = vec![
self.create_warmup_batch(*shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size), self.create_warmup_batch(*shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size, false),
self.create_warmup_batch(*shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size) self.create_warmup_batch(*shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size, false)
]; ];
let request = tonic::Request::new(WarmupRequest { batches }).inject_context(); let request = tonic::Request::new(WarmupRequest { batches }).inject_context();
let _response = self.stub.warmup(request).await?.into_inner(); let _response = self.stub.warmup(request).await?.into_inner();
} }
// Send batches with default params to warm up greedy search
let mut greedy_shapes: Vec<(u32, u32)> = Vec::with_capacity(batch_sizes.len());
for batch_size in &batch_sizes {
greedy_shapes.push((*batch_size, seq_bucket_size.clone()));
}
for greedy_shape in greedy_shapes.iter() {
let batches: Vec<Batch> = vec![
self.create_warmup_batch(*greedy_shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size, true),
self.create_warmup_batch(*greedy_shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size, true),
];
let request = tonic::Request::new(WarmupRequest { batches }).inject_context();
let _response = self.stub.warmup(request).await?.into_inner();
}
Ok(None) // No support for maximum total tokens Ok(None) // No support for maximum total tokens
} }
@ -155,16 +168,25 @@ impl Client {
max_input_length: u32, max_input_length: u32,
max_total_tokens: u32, max_total_tokens: u32,
seq_bucket_size: u32, seq_bucket_size: u32,
default_params: bool,
) -> Batch { ) -> Batch {
*id_counter += 1; *id_counter += 1;
let (batch_size, input_length) = shape; let (batch_size, input_length) = shape;
let mut requests = Vec::new(); let mut requests = Vec::new();
for request_id in 0..batch_size { for request_id in 0..batch_size {
requests.push(Request { let req_params = if default_params {
id: *id_counter + request_id as u64, Some(NextTokenChooserParameters {
inputs: self.get_random_input(input_length, seq_bucket_size), temperature: 1.0,
truncate: max_input_length, top_k: 0,
parameters: Some(NextTokenChooserParameters { top_p: 1.0,
typical_p: 1.0,
do_sample: false,
seed: 0,
repetition_penalty: 1.0,
watermark: false,
})
} else {
Some(NextTokenChooserParameters {
temperature: 0.9, temperature: 0.9,
top_k: 10, top_k: 10,
top_p: 0.9, top_p: 0.9,
@ -173,7 +195,13 @@ impl Client {
seed: 0, seed: 0,
repetition_penalty: 1.2, repetition_penalty: 1.2,
watermark: true, watermark: true,
}), })
};
requests.push(Request {
id: *id_counter + request_id as u64,
inputs: self.get_random_input(input_length, seq_bucket_size),
truncate: max_input_length,
parameters: req_params,
stopping_parameters: Some(StoppingCriteriaParameters { stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: cmp::min(10, max_total_tokens - max_input_length), max_new_tokens: cmp::min(10, max_total_tokens - max_input_length),
stop_sequences: vec![], stop_sequences: vec![],