From 2122acc60f328cc3e58a0f8dd7626415b99271be Mon Sep 17 00:00:00 2001
From: Karol Damaszke
Date: Wed, 28 Feb 2024 10:40:13 +0100
Subject: [PATCH] Add warmup for all possible shapes for prefill #49 (#81)

---
 README.md                               |   2 +-
 proto/generate.proto                    |   2 +-
 router/client/Cargo.toml                |   1 +
 router/client/src/client.rs             | 123 ++++++++++++++----
 .../models/causal_lm.py                 |  22 ++++
 server/text_generation_server/server.py |  17 ++-
 6 files changed, 129 insertions(+), 38 deletions(-)

diff --git a/README.md b/README.md
index f7208bf8..6afb47b2 100644
--- a/README.md
+++ b/README.md
@@ -86,7 +86,7 @@ Environment Variables Added:
 | PAD_SEQUENCE_TO_MULTIPLE_OF | integer | 128 | For prefill operation, sequences will be padded to a multiple of provided value. | add -e in docker run command |
 | SKIP_TOKENIZER_IN_TGI | True/False | False | Skip tokenizer for input/output processing | add -e in docker run command |
 | TGI_PROFILER_ENABLED | True/False | False | Collect high-level server tracing events | add -e in docker run command |
-
+| WARMUP_ENABLED | True/False | True | Enable warmup during server initialization to recompile all graphs. This can increase TGI setup time. | add -e in docker run command |
diff --git a/proto/generate.proto b/proto/generate.proto
index c873e661..c7f9f3c1 100644
--- a/proto/generate.proto
+++ b/proto/generate.proto
@@ -213,7 +213,7 @@ message DecodeResponse {
 
 message WarmupRequest {
     /// Batch to warmup on
-    Batch batch = 1;
+    repeated Batch batches = 1;
 }
 
 /// Empty response
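
The proto change above is the wire-level half of this patch: WarmupRequest no longer carries a single optional Batch but a repeated batches field, which is what lets the router send two identically shaped batches per warmup step (the second one exists to exercise the concatenate path). A minimal sketch of the new message shape, illustration only and not part of the patch; the generate_pb2 import path and the empty placeholder batches are assumptions, while the field names come from the proto and router code in this patch:

    # Illustration only: construct the new WarmupRequest shape in Python.
    # The generate_pb2 import path and the empty placeholder batches are
    # assumptions; field names (batches, id, requests, size, max_tokens) come
    # from the proto change above and the router code below.
    from text_generation_server.pb import generate_pb2

    batch_a = generate_pb2.Batch(id=1, requests=[], size=0, max_tokens=2048)
    batch_b = generate_pb2.Batch(id=2, requests=[], size=0, max_tokens=2048)
    warmup_request = generate_pb2.WarmupRequest(batches=[batch_a, batch_b])
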
diff --git a/router/client/Cargo.toml b/router/client/Cargo.toml
index d0131784..bc4ae72e 100644
--- a/router/client/Cargo.toml
+++ b/router/client/Cargo.toml
@@ -9,6 +9,7 @@ homepage.workspace = true
 futures = "^0.3"
 grpc-metadata = { path = "../grpc-metadata" }
 prost = "^0.12"
+rand = "0.8.5"
 thiserror = "^1.0"
 tokio = { version = "^1.32", features = ["sync"] }
 tonic = "^0.10"
diff --git a/router/client/src/client.rs b/router/client/src/client.rs
index 341e70fd..530e6df3 100644
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -2,8 +2,10 @@
 use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient;
 use crate::pb::generate::v1::*;
 use crate::Result;
+use std::env;
+use rand::{distributions::Uniform, Rng};
 use grpc_metadata::InjectTelemetryContext;
-use std::cmp::min;
+use std::cmp;
 use tonic::transport::{Channel, Uri};
 use tracing::instrument;
 
@@ -105,48 +107,115 @@
         max_prefill_tokens: u32,
         max_total_tokens: u32,
     ) -> Result<Option<u32>> {
-        let mut n_tokens = 0;
+        let warmup_enabled: bool = env::var("WARMUP_ENABLED").ok().map_or(true, |value| value.to_lowercase() == "true");
+        if !warmup_enabled {
+            return Ok(None);
+        }
+
+        let read_env_var = |key: &str, default: u32| -> u32 {
+            env::var(key).ok().map_or(default, |value| value.parse::<u32>().unwrap())
+        };
+
+        // get all possible prefill batch sizes
+        let max_prefill_batch_size: u32 = max_prefill_tokens / max_input_length;
+        let prefill_bucket_size: u32 = read_env_var("PREFILL_BATCH_BUCKET_SIZE", 1);
+        let batch_sizes: Vec<u32> = (1..max_prefill_batch_size+1).step_by(prefill_bucket_size as usize).collect();
+
+        // get all possible sequence lengths for prefill
+        let seq_bucket_size: u32 = read_env_var("PAD_SEQUENCE_TO_MULTIPLE_OF", 128);
+        let seq_lengths: Vec<u32> = (seq_bucket_size..max_input_length+1).step_by(seq_bucket_size as usize).collect();
+
+        // execute batch for each combination of batch size and sequence length
+        let mut shapes: Vec<(u32, u32)> = Vec::with_capacity(batch_sizes.len() * seq_lengths.len());
+        for batch_size in &batch_sizes {
+            for seq_length in &seq_lengths {
+                shapes.push((*batch_size, *seq_length));
+            }
+        }
+
+        let mut id_counter: u64 = 0;
+        for shape in shapes.iter() {
+            // create two batches in order to trigger concatenate operation
+            let batches: Vec<Batch> = vec![
+                self.create_warmup_batch(*shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size),
+                self.create_warmup_batch(*shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size)
+            ];
+            let request = tonic::Request::new(WarmupRequest { batches }).inject_context();
+            let _response = self.stub.warmup(request).await?.into_inner();
+        }
+
+        Ok(None) // No support for maximum total tokens
+    }
+
+    #[instrument(skip_all)]
+    fn create_warmup_batch(
+        &mut self,
+        shape: (u32, u32),
+        id_counter: &mut u64,
+        max_input_length: u32,
+        max_total_tokens: u32,
+        seq_bucket_size: u32,
+    ) -> Batch {
+        *id_counter += 1;
+        let (batch_size, input_length) = shape;
         let mut requests = Vec::new();
-        // Create requests
-        while n_tokens < max_prefill_tokens {
-            let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
+        for request_id in 0..batch_size {
             requests.push(Request {
-                id: 0,
-                // We truncate the input on the server side to be sure that it has the correct size
-                inputs: "_test ".to_string().repeat(max_input_length as usize),
-                truncate,
-                // Set sampling parameters to also take these ops into account in the max memory
+                id: *id_counter + request_id as u64,
+                inputs: self.get_random_input(input_length, seq_bucket_size),
+                truncate: max_input_length,
                 parameters: Some(NextTokenChooserParameters {
-                    temperature: 0.9,
-                    top_k: 10,
-                    top_p: 0.9,
-                    typical_p: 0.9,
+                    temperature: 1.0,
+                    top_k: 0,
+                    top_p: 1.0,
+                    typical_p: 1.0,
                     do_sample: false,
                     seed: 0,
-                    repetition_penalty: 1.2,
-                    watermark: true,
+                    repetition_penalty: 1.0,
+                    watermark: false,
                 }),
                 stopping_parameters: Some(StoppingCriteriaParameters {
-                    max_new_tokens: max_total_tokens - truncate,
+                    max_new_tokens: 10,
                     stop_sequences: vec![],
                     ignore_eos_token: true,
                 }),
-                prefill_logprobs: true,
-                top_n_tokens: 20,
+                prefill_logprobs: false,
+                top_n_tokens: 0,
             });
-            n_tokens += max_input_length;
         }
 
-        let batch = Batch {
-            id: 0,
+        Batch {
+            id: *id_counter,
             size: requests.len() as u32,
             requests,
-            max_tokens: 0,
-        };
+            max_tokens: max_total_tokens,
+        }
+    }
 
-        let request = tonic::Request::new(WarmupRequest { batch: Some(batch) }).inject_context();
-        let response = self.stub.warmup(request).await?.into_inner();
-        Ok(response.max_supported_total_tokens)
+    #[instrument(skip_all)]
+    fn get_random_input(
+        &mut self,
+        input_length: u32,
+        seq_bucket_size: u32,
+    ) -> String {
+        let skip_tokenizer_in_tgi: bool = env::var("SKIP_TOKENIZER_IN_TGI")
+            .ok()
+            .map_or(false, |value| value.to_lowercase() == "true");
+        if skip_tokenizer_in_tgi {
+            // generate random tokens
+            let mut rng = rand::thread_rng();
+            let range = Uniform::new(2, 8192);
+            let tokens = input_length - seq_bucket_size / 2;
+            (0..tokens)
+                .map(|_| rng.sample(&range).to_string())
+                .collect::<Vec<String>>()
+                .join(", ")
+        } else {
+            // repeat test string to get expected input shape
+            let bucket_id = input_length / seq_bucket_size;
+            let repeats = cmp::max(1, (bucket_id - 1) * seq_bucket_size / 2);
+            "_test ".to_string().repeat(repeats as usize)
+        }
     }
 
     /// Generate one token for each request in the given batch
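
To make the enumeration above concrete: batch sizes run from 1 to max_prefill_tokens / max_input_length in steps of PREFILL_BATCH_BUCKET_SIZE, sequence lengths run from PAD_SEQUENCE_TO_MULTIPLE_OF up to max_input_length in steps of the same bucket size, and every (batch_size, seq_length) pair is warmed with two batches. A small Python sketch of the same grid (illustration only; the shipped logic is the Rust above, and warmup_shapes is just an illustrative helper name):

    # Illustration only: mirrors the shape enumeration in router/client/src/client.rs.
    import os
    from typing import List, Tuple


    def warmup_shapes(max_input_length: int, max_prefill_tokens: int) -> List[Tuple[int, int]]:
        prefill_bucket_size = int(os.environ.get("PREFILL_BATCH_BUCKET_SIZE", "1"))
        seq_bucket_size = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", "128"))
        max_prefill_batch_size = max_prefill_tokens // max_input_length
        batch_sizes = range(1, max_prefill_batch_size + 1, prefill_bucket_size)
        seq_lengths = range(seq_bucket_size, max_input_length + 1, seq_bucket_size)
        return [(bs, sl) for bs in batch_sizes for sl in seq_lengths]


    # e.g. max_input_length=1024 and max_prefill_tokens=4096 with the default
    # bucket sizes give 4 x 8 = 32 shapes, i.e. 64 warmup prefill batches
    shapes = warmup_shapes(max_input_length=1024, max_prefill_tokens=4096)
    print(len(shapes), "shapes ->", 2 * len(shapes), "warmup batches")
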
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index ac7e3176..42da4d06 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -990,3 +990,25 @@ class CausalLM(Model):
         else:
             self.hb_profiler.step()
         return generations, batch if not stopped else None
+
+    def warmup(self, batches: List[CausalLMBatch]) -> None:
+        self.shifting_warmup()
+
+        if len(batches) < 2:
+            return
+
+        # prefill
+        _, prefill_batch = self.generate_token([batches[0]])
+        # decode
+        _, decode_batch = self.generate_token([prefill_batch])
+        # prefill
+        _, prefill_batch = self.generate_token([batches[1]])
+        # concatenate and decode
+        _, decode_batch = self.generate_token([decode_batch, prefill_batch])
+        # decodes
+        while decode_batch is not None:
+            _, decode_batch = self.generate_token([decode_batch])
+
+    def shifting_warmup(self) -> None:
+        # TODO: add warmup for all possible shift variants
+        pass
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 6f3e49f2..841bda93 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -67,16 +67,15 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
 
     async def Warmup(self, request, context):
-        with self.profiler.record_event("external", "warmup"):
-            # batch = self.model.batch_type.from_pb(
-            #     request.batch, self.model.tokenizer, self.model.dtype, self.model.device
-            # )
-            # max_supported_total_tokens = self.model.warmup(batch)
+        def batch_from_pb(batch):
+            return self.model.batch_type.from_pb(
+                batch, self.model.tokenizer, self.model.dtype, self.model.device, self.model.is_optimized_for_gaudi
+            )
+
+        with self.profiler.record_event("external", "warmup"):
+            batches = [batch_from_pb(batch) for batch in request.batches]
+            self.model.warmup(batches)
 
-            # return generate_pb2.WarmupResponse(
-            #     max_supported_total_tokens=max_supported_total_tokens
-            # )
-            logger.warning("Warmup is not enabled on HPU.")
         return generate_pb2.WarmupResponse()
 
     async def Prefill(self, request, context):
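
On the server side, CausalLM.warmup drives each pair of warmup batches through one prefill, one decode, a second prefill, a concatenate-plus-decode step, and then further decodes until the batch is exhausted, so the prefill, decode and concatenation graphs for that shape are compiled up front rather than on the first real request. A minimal sketch of that call order, illustration only and not part of the patch; StubModel and warmup_one_shape are hypothetical stand-ins used only to show the sequence:

    # Illustration only: the order of generate_token() calls issued by
    # CausalLM.warmup() for one (batch_size, seq_length) shape. StubModel is a
    # hypothetical stand-in that just records the call sequence.
    from typing import Any, List, Optional, Tuple


    class StubModel:
        def __init__(self) -> None:
            self.calls: List[str] = []
            self._steps = 0

        def generate_token(self, batches: List[Any]) -> Tuple[None, Optional[object]]:
            self._steps += 1
            self.calls.append(f"step {self._steps}: generate_token with {len(batches)} batch(es)")
            # pretend generation is exhausted after a few decode steps
            return None, (object() if self._steps < 6 else None)


    def warmup_one_shape(model: StubModel, batch_a: Any, batch_b: Any) -> None:
        # same sequence as CausalLM.warmup in the diff above
        _, prefill_batch = model.generate_token([batch_a])                     # prefill
        _, decode_batch = model.generate_token([prefill_batch])                # decode
        _, prefill_batch = model.generate_token([batch_b])                     # second prefill
        _, decode_batch = model.generate_token([decode_batch, prefill_batch])  # concatenate + decode
        while decode_batch is not None:                                        # remaining decodes
            _, decode_batch = model.generate_token([decode_batch])


    model = StubModel()
    warmup_one_shape(model, "prefill batch 1", "prefill batch 2")
    print("\n".join(model.calls))
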