Add warmup for all possible shapes for prefill #49 (#81)

Karol Damaszke 2024-02-28 10:40:13 +01:00 committed by GitHub
parent 31bed905d4
commit 2122acc60f
6 changed files with 129 additions and 38 deletions

View File

@@ -86,7 +86,7 @@ Environment Variables Added:
| PAD_SEQUENCE_TO_MULTIPLE_OF | integer | 128 | For prefill operation, sequences will be padded to a multiple of provided value. | add -e in docker run command |
| SKIP_TOKENIZER_IN_TGI | True/False | False | Skip tokenizer for input/output processing | add -e in docker run command |
| TGI_PROFILER_ENABLED | True/False | False | Collect high-level server tracing events | add -e in docker run command |
| WARMUP_ENABLED | True/False | True | Enable warmup during server initialization to recompile all graphs. This can increase TGI setup time. | add -e in docker run command |
</div>
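
The warmup-related variables above, together with PREFILL_BATCH_BUCKET_SIZE (default 1, read by the router client), determine which prefill shapes get compiled during warmup. A minimal sketch of that shape enumeration, in Python for illustration only; the commit's actual implementation is the Rust client change further down:

import os

def warmup_shapes(max_prefill_tokens, max_input_length):
    # Mirrors the batch-size / sequence-length bucketing in the client's warmup below.
    batch_bucket = int(os.environ.get("PREFILL_BATCH_BUCKET_SIZE", "1"))
    seq_bucket = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", "128"))
    max_batch_size = max_prefill_tokens // max_input_length
    batch_sizes = range(1, max_batch_size + 1, batch_bucket)
    seq_lengths = range(seq_bucket, max_input_length + 1, seq_bucket)
    return [(bs, sl) for bs in batch_sizes for sl in seq_lengths]

# e.g. max_prefill_tokens=4096, max_input_length=1024 with the defaults above
# gives batch sizes 1..4 and sequence lengths 128, 256, ..., 1024: 32 shapes,
# each warmed up with two identical batches so the concatenate path is hit too.
print(len(warmup_shapes(4096, 1024)))  # 32
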

View File

@@ -213,7 +213,7 @@ message DecodeResponse {
message WarmupRequest {
/// Batch to warmup on
Batch batch = 1;
repeated Batch batches = 1;
}
/// Empty response

View File

@@ -9,6 +9,7 @@ homepage.workspace = true
futures = "^0.3"
grpc-metadata = { path = "../grpc-metadata" }
prost = "^0.12"
rand = "0.8.5"
thiserror = "^1.0"
tokio = { version = "^1.32", features = ["sync"] }
tonic = "^0.10"

View File

@@ -2,8 +2,10 @@
use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient;
use crate::pb::generate::v1::*;
use crate::Result;
use std::env;
use rand::{distributions::Uniform, Rng};
use grpc_metadata::InjectTelemetryContext;
use std::cmp::min;
use std::cmp;
use tonic::transport::{Channel, Uri};
use tracing::instrument;
@@ -105,48 +107,115 @@ impl Client {
max_prefill_tokens: u32,
max_total_tokens: u32,
) -> Result<Option<u32>> {
let mut n_tokens = 0;
let warmup_enabled: bool = env::var("WARMUP_ENABLED").ok().map_or(true, |value| value.to_lowercase() == "true");
if !warmup_enabled {
return Ok(None);
}
let read_env_var = |key: &str, default: u32| -> u32 {
env::var(key).ok().map_or(default, |value| value.parse::<u32>().unwrap())
};
// get all possible prefill batch sizes
let max_prefill_batch_size: u32 = max_prefill_tokens / max_input_length;
let prefill_bucket_size: u32 = read_env_var("PREFILL_BATCH_BUCKET_SIZE", 1);
let batch_sizes: Vec<u32> = (1..max_prefill_batch_size+1).step_by(prefill_bucket_size as usize).collect();
// get all possible sequence lengths for prefill
let seq_bucket_size: u32 = read_env_var("PAD_SEQUENCE_TO_MULTIPLE_OF", 128);
let seq_lengths: Vec<u32> = (seq_bucket_size..max_input_length+1).step_by(seq_bucket_size as usize).collect();
// execute batch for each combination of batch size and sequence length
let mut shapes: Vec<(u32, u32)> = Vec::with_capacity(batch_sizes.len() * seq_lengths.len());
for batch_size in &batch_sizes {
for seq_length in &seq_lengths {
shapes.push((*batch_size, *seq_length));
}
}
let mut id_counter: u64 = 0;
for shape in shapes.iter() {
// create two batches in order to trigger concatenate operation
let batches: Vec<Batch> = vec![
self.create_warmup_batch(*shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size),
self.create_warmup_batch(*shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size)
];
let request = tonic::Request::new(WarmupRequest { batches }).inject_context();
let _response = self.stub.warmup(request).await?.into_inner();
}
Ok(None) // No support for maximum total tokens
}
#[instrument(skip_all)]
fn create_warmup_batch(
&mut self,
shape: (u32, u32),
id_counter: &mut u64,
max_input_length: u32,
max_total_tokens: u32,
seq_bucket_size: u32,
) -> Batch {
*id_counter += 1;
let (batch_size, input_length) = shape;
let mut requests = Vec::new();
// Create requests
while n_tokens < max_prefill_tokens {
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
for request_id in 0..batch_size {
requests.push(Request {
id: 0,
// We truncate the input on the server side to be sure that it has the correct size
inputs: "_test ".to_string().repeat(max_input_length as usize),
truncate,
// Set sampling parameters to also take these ops into account in the max memory
id: *id_counter + request_id as u64,
inputs: self.get_random_input(input_length, seq_bucket_size),
truncate: max_input_length,
parameters: Some(NextTokenChooserParameters {
temperature: 0.9,
top_k: 10,
top_p: 0.9,
typical_p: 0.9,
temperature: 1.0,
top_k: 0,
top_p: 1.0,
typical_p: 1.0,
do_sample: false,
seed: 0,
repetition_penalty: 1.2,
watermark: true,
repetition_penalty: 1.0,
watermark: false,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: max_total_tokens - truncate,
max_new_tokens: 10,
stop_sequences: vec![],
ignore_eos_token: true,
}),
prefill_logprobs: true,
top_n_tokens: 20,
prefill_logprobs: false,
top_n_tokens: 0,
});
n_tokens += max_input_length;
}
let batch = Batch {
id: 0,
Batch {
id: *id_counter,
size: requests.len() as u32,
requests,
max_tokens: 0,
};
max_tokens: max_total_tokens,
}
}
let request = tonic::Request::new(WarmupRequest { batch: Some(batch) }).inject_context();
let response = self.stub.warmup(request).await?.into_inner();
Ok(response.max_supported_total_tokens)
#[instrument(skip_all)]
fn get_random_input(
&mut self,
input_length: u32,
seq_bucket_size: u32,
) -> String {
let skip_tokenizer_in_tgi: bool = env::var("SKIP_TOKENIZER_IN_TGI")
.ok()
.map_or(false, |value| value.to_lowercase() == "true");
if skip_tokenizer_in_tgi {
// generate random tokens
let mut rng = rand::thread_rng();
let range = Uniform::new(2, 8192);
let tokens = input_length - seq_bucket_size / 2;
(0..tokens)
.map(|_| rng.sample(&range).to_string())
.collect::<Vec<String>>()
.join(", ")
} else {
// repeat test string to get expected input shape
let bucket_id = input_length / seq_bucket_size;
let repeats = cmp::max(1, (bucket_id - 1) * seq_bucket_size / 2);
"_test ".to_string().repeat(repeats as usize)
}
}
/// Generate one token for each request in the given batch
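
For illustration, here is the bucketing arithmetic of get_random_input evaluated for one assumed shape (input_length=512 with the default PAD_SEQUENCE_TO_MULTIPLE_OF=128). The numbers follow directly from the formulas above; how many tokens the repeated "_test " string actually produces depends on the model's tokenizer:

input_length, seq_bucket_size = 512, 128

# SKIP_TOKENIZER_IN_TGI=true: the prompt is a comma-separated list of raw token ids
tokens = input_length - seq_bucket_size // 2               # 448 random ids
# SKIP_TOKENIZER_IN_TGI=false: the prompt is a repeated "_test " string
bucket_id = input_length // seq_bucket_size                # 4
repeats = max(1, (bucket_id - 1) * seq_bucket_size // 2)   # 192 repeats

In the skip-tokenizer mode, 448 ids pad up to the 512 bucket; in the tokenizer mode the repeat count targets the same bucket, assuming roughly two tokens per "_test " repeat.
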

View File

@@ -990,3 +990,25 @@ class CausalLM(Model):
else:
self.hb_profiler.step()
return generations, batch if not stopped else None
def warmup(self, batches: List[CausalLMBatch]) -> None:
self.shifting_warmup()
if len(batches) < 2:
return
# prefill
_, prefill_batch = self.generate_token([batches[0]])
# decode
_, decode_batch = self.generate_token([prefill_batch])
# prefill
_, prefill_batch = self.generate_token([batches[1]])
# concatenate and decode
_, decode_batch = self.generate_token([decode_batch, prefill_batch])
# decodes
while decode_batch is not None:
_, decode_batch = self.generate_token([decode_batch])
def shifting_warmup(self) -> None:
# TODO: add warmup for all possible shift variants
pass
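
As a sketch only (the Batch and generate_token stand-ins below are hypothetical, not the commit's classes), this is the call order CausalLM.warmup runs for one pair of warmup batches and the graph shapes it exercises: prefill, single-batch decode, and a decode over the concatenated batches; the real method then keeps decoding until the batch finishes.

class Batch:
    def __init__(self, size, prefilled=False):
        self.size, self.prefilled = size, prefilled

def generate_token(batches):
    # Stand-in: report the phase and shape, return a merged "cached" batch.
    phase = "decode" if all(b.prefilled for b in batches) else "prefill"
    merged = Batch(sum(b.size for b in batches), prefilled=True)
    print(f"{phase}: batch_size={merged.size}")
    return None, merged

b0, b1 = Batch(2), Batch(2)                         # the two batches sent per shape
_, prefill_batch = generate_token([b0])             # prefill: batch_size=2
_, decode_batch = generate_token([prefill_batch])   # decode:  batch_size=2
_, prefill_batch = generate_token([b1])             # prefill: batch_size=2
_, decode_batch = generate_token([decode_batch, prefill_batch])  # decode: batch_size=4 (concatenate)
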

View File

@@ -67,16 +67,15 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
async def Warmup(self, request, context):
with self.profiler.record_event("external", "warmup"):
# batch = self.model.batch_type.from_pb(
# request.batch, self.model.tokenizer, self.model.dtype, self.model.device
# )
# max_supported_total_tokens = self.model.warmup(batch)
def batch_from_pb(batch):
return self.model.batch_type.from_pb(
batch, self.model.tokenizer, self.model.dtype, self.model.device, self.model.is_optimized_for_gaudi
)
with self.profiler.record_event("external", "warmup"):
batches = [batch_from_pb(batch) for batch in request.batches]
self.model.warmup(batches)
# return generate_pb2.WarmupResponse(
# max_supported_total_tokens=max_supported_total_tokens
# )
logger.warning("Warmup is not enabled on HPU.")
return generate_pb2.WarmupResponse()
async def Prefill(self, request, context):