Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-24 00:12:08 +00:00.
Commit 2122acc60f (parent 31bed905d4).
@@ -86,7 +86,7 @@ Environment Variables Added:

| PAD_SEQUENCE_TO_MULTIPLE_OF | integer | 128 | For prefill operation, sequences will be padded to a multiple of provided value. | add -e in docker run command |
| SKIP_TOKENIZER_IN_TGI | True/False | False | Skip tokenizer for input/output processing | add -e in docker run command |
| TGI_PROFILER_ENABLED | True/False | False | Collect high-level server tracing events | add -e in docker run command |
| WARMUP_ENABLED | True/False | True | Enable warmup during server initialization to recompile all graphs. This can increase TGI setup time. | add -e in docker run command |

</div>
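The warmup-related variables above are read by the Rust router client (see the client.rs hunk below). As a rough, illustrative sketch of the parsing convention — boolean flags accept "True"/"False" case-insensitively and numeric ones fall back to the documented default — the helpers below mirror that pattern; the function names are illustrative and not part of the codebase:

```rust
use std::env;

// Illustrative helpers only; they mirror the env-var parsing pattern used in
// the router client changes further down in this diff.
fn bool_env(key: &str, default: bool) -> bool {
    // "True"/"False" (any casing), as documented in the table above.
    env::var(key)
        .ok()
        .map_or(default, |value| value.to_lowercase() == "true")
}

fn u32_env(key: &str, default: u32) -> u32 {
    env::var(key)
        .ok()
        .map_or(default, |value| value.parse::<u32>().unwrap_or(default))
}

fn main() {
    // Defaults match the table: warmup on, tokenizer not skipped, pad to 128.
    let warmup_enabled = bool_env("WARMUP_ENABLED", true);
    let skip_tokenizer = bool_env("SKIP_TOKENIZER_IN_TGI", false);
    let pad_multiple = u32_env("PAD_SEQUENCE_TO_MULTIPLE_OF", 128);
    println!("warmup={warmup_enabled} skip_tokenizer={skip_tokenizer} pad_to={pad_multiple}");
}
```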
@@ -213,7 +213,7 @@ message DecodeResponse {

 message WarmupRequest {
     /// Batch to warmup on
-    Batch batch = 1;
+    repeated Batch batches = 1;
 }

 /// Empty response
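For context, prost generates an `Option<Batch>` field for the old singular message field and a plain `Vec<Batch>` for the new `repeated` field, which is what lets the router send two batches per warmup request. A toy, self-contained sketch of that shape change (the structs below are stand-ins, not the generated code):

```rust
// Toy stand-ins for the prost-generated types, just to show the shape change.
// The real `Batch` has more fields; this is not the generated code.
#[derive(Clone, Default)]
struct Batch {
    id: u64,
    size: u32,
}

// Before: `Batch batch = 1;` -> prost generates an Option<Batch> field.
struct WarmupRequestV1 {
    batch: Option<Batch>,
}

// After: `repeated Batch batches = 1;` -> prost generates a Vec<Batch> field.
struct WarmupRequestV2 {
    batches: Vec<Batch>,
}

fn main() {
    let b = Batch { id: 1, size: 2 };
    let _old = WarmupRequestV1 { batch: Some(b.clone()) };
    // The router now sends two batches per shape so the server also warms up
    // the concatenate path (see the client.rs changes below).
    let _new = WarmupRequestV2 { batches: vec![b.clone(), b] };
}
```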
@@ -9,6 +9,7 @@ homepage.workspace = true

 futures = "^0.3"
 grpc-metadata = { path = "../grpc-metadata" }
 prost = "^0.12"
+rand = "0.8.5"
 thiserror = "^1.0"
 tokio = { version = "^1.32", features = ["sync"] }
 tonic = "^0.10"
@@ -2,8 +2,10 @@

 use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient;
 use crate::pb::generate::v1::*;
 use crate::Result;
+use std::env;
+use rand::{distributions::Uniform, Rng};
 use grpc_metadata::InjectTelemetryContext;
-use std::cmp::min;
+use std::cmp;
 use tonic::transport::{Channel, Uri};
 use tracing::instrument;

@@ -105,48 +107,115 @@ impl Client {
         max_prefill_tokens: u32,
         max_total_tokens: u32,
     ) -> Result<Option<u32>> {
-        let mut n_tokens = 0;
+        let warmup_enabled: bool = env::var("WARMUP_ENABLED").ok().map_or(true, |value| value.to_lowercase() == "true");
+        if !warmup_enabled {
+            return Ok(None);
+        }
+
+        let read_env_var = |key: &str, default: u32| -> u32 {
+            env::var(key).ok().map_or(default, |value| value.parse::<u32>().unwrap())
+        };
+
+        // get all possible prefill batch sizes
+        let max_prefill_batch_size: u32 = max_prefill_tokens / max_input_length;
+        let prefill_bucket_size: u32 = read_env_var("PREFILL_BATCH_BUCKET_SIZE", 1);
+        let batch_sizes: Vec<u32> = (1..max_prefill_batch_size+1).step_by(prefill_bucket_size as usize).collect();
+
+        // get all possible sequence lengths for prefill
+        let seq_bucket_size: u32 = read_env_var("PAD_SEQUENCE_TO_MULTIPLE_OF", 128);
+        let seq_lengths: Vec<u32> = (seq_bucket_size..max_input_length+1).step_by(seq_bucket_size as usize).collect();
+
+        // execute batch for each combination of batch size and sequence length
+        let mut shapes: Vec<(u32, u32)> = Vec::with_capacity(batch_sizes.len() * seq_lengths.len());
+        for batch_size in &batch_sizes {
+            for seq_length in &seq_lengths {
+                shapes.push((*batch_size, *seq_length));
+            }
+        }
+
+        let mut id_counter: u64 = 0;
+        for shape in shapes.iter() {
+            // create two batches in order to trigger concatenate operation
+            let batches: Vec<Batch> = vec![
+                self.create_warmup_batch(*shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size),
+                self.create_warmup_batch(*shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size)
+            ];
+            let request = tonic::Request::new(WarmupRequest { batches }).inject_context();
+            let _response = self.stub.warmup(request).await?.into_inner();
+        }
+
+        Ok(None) // No support for maximum total tokens
+    }
+
+    #[instrument(skip_all)]
+    fn create_warmup_batch(
+        &mut self,
+        shape: (u32, u32),
+        id_counter: &mut u64,
+        max_input_length: u32,
+        max_total_tokens: u32,
+        seq_bucket_size: u32,
+    ) -> Batch {
+        *id_counter += 1;
+        let (batch_size, input_length) = shape;
         let mut requests = Vec::new();
         // Create requests
-        while n_tokens < max_prefill_tokens {
-            let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
+        for request_id in 0..batch_size {
             requests.push(Request {
-                id: 0,
-                // We truncate the input on the server side to be sure that it has the correct size
-                inputs: "_test ".to_string().repeat(max_input_length as usize),
-                truncate,
-                // Set sampling parameters to also take these ops into account in the max memory
+                id: *id_counter + request_id as u64,
+                inputs: self.get_random_input(input_length, seq_bucket_size),
+                truncate: max_input_length,
                 parameters: Some(NextTokenChooserParameters {
-                    temperature: 0.9,
-                    top_k: 10,
-                    top_p: 0.9,
-                    typical_p: 0.9,
+                    temperature: 1.0,
+                    top_k: 0,
+                    top_p: 1.0,
+                    typical_p: 1.0,
                     do_sample: false,
                     seed: 0,
-                    repetition_penalty: 1.2,
-                    watermark: true,
+                    repetition_penalty: 1.0,
+                    watermark: false,
                 }),
                 stopping_parameters: Some(StoppingCriteriaParameters {
-                    max_new_tokens: max_total_tokens - truncate,
+                    max_new_tokens: 10,
                     stop_sequences: vec![],
                     ignore_eos_token: true,
                 }),
-                prefill_logprobs: true,
-                top_n_tokens: 20,
+                prefill_logprobs: false,
+                top_n_tokens: 0,
             });
-            n_tokens += max_input_length;
         }

-        let batch = Batch {
-            id: 0,
+        Batch {
+            id: *id_counter,
             size: requests.len() as u32,
             requests,
-            max_tokens: 0,
-        };
+            max_tokens: max_total_tokens,
+        }
     }

-        let request = tonic::Request::new(WarmupRequest { batch: Some(batch) }).inject_context();
-        let response = self.stub.warmup(request).await?.into_inner();
-        Ok(response.max_supported_total_tokens)
+    #[instrument(skip_all)]
+    fn get_random_input(
+        &mut self,
+        input_length: u32,
+        seq_bucket_size: u32,
+    ) -> String {
+        let skip_tokenizer_in_tgi: bool = env::var("SKIP_TOKENIZER_IN_TGI")
+            .ok()
+            .map_or(false, |value| value.to_lowercase() == "true");
+        if skip_tokenizer_in_tgi {
+            // generate random tokens
+            let mut rng = rand::thread_rng();
+            let range = Uniform::new(2, 8192);
+            let tokens = input_length - seq_bucket_size / 2;
+            (0..tokens)
+                .map(|_| rng.sample(&range).to_string())
+                .collect::<Vec<String>>()
+                .join(", ")
+        } else {
+            // repeat test string to get expected input shape
+            let bucket_id = input_length / seq_bucket_size;
+            let repeats = cmp::max(1, (bucket_id - 1) * seq_bucket_size / 2);
+            "_test ".to_string().repeat(repeats as usize)
+        }
+    }

     /// Generate one token for each request in the given batch
@@ -990,3 +990,25 @@ class CausalLM(Model):
             else:
                 self.hb_profiler.step()
         return generations, batch if not stopped else None
+
+    def warmup(self, batches: List[CausalLMBatch]) -> None:
+        self.shifting_warmup()
+
+        if len(batches) < 2:
+            return
+
+        # prefill
+        _, prefill_batch = self.generate_token([batches[0]])
+        # decode
+        _, decode_batch = self.generate_token([prefill_batch])
+        # prefill
+        _, prefill_batch = self.generate_token([batches[1]])
+        # concatenate and decode
+        _, decode_batch = self.generate_token([decode_batch, prefill_batch])
+        # decodes
+        while decode_batch is not None:
+            _, decode_batch = self.generate_token([decode_batch])
+
+    def shifting_warmup(self) -> None:
+        # TODO: add warmup for all possible shift variants
+        pass
@@ -67,16 +67,15 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

     async def Warmup(self, request, context):
-        with self.profiler.record_event("external", "warmup"):
-            # batch = self.model.batch_type.from_pb(
-            #     request.batch, self.model.tokenizer, self.model.dtype, self.model.device
-            # )
-            # max_supported_total_tokens = self.model.warmup(batch)
+        def batch_from_pb(batch):
+            return self.model.batch_type.from_pb(
+                batch, self.model.tokenizer, self.model.dtype, self.model.device, self.model.is_optimized_for_gaudi
+            )

-            # return generate_pb2.WarmupResponse(
-            #     max_supported_total_tokens=max_supported_total_tokens
-            # )
-            logger.warning("Warmup is not enabled on HPU.")
+        with self.profiler.record_event("external", "warmup"):
+            batches = [batch_from_pb(batch) for batch in request.batches]
+            self.model.warmup(batches)
+
         return generate_pb2.WarmupResponse()

     async def Prefill(self, request, context):