Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-24 00:12:08 +00:00)
Commit 2122acc60f, parent 31bed905d4
@@ -86,7 +86,7 @@ Environment Variables Added:
 | PAD_SEQUENCE_TO_MULTIPLE_OF | integer | 128 | For prefill operation, sequences will be padded to a multiple of provided value. | add -e in docker run command |
 | SKIP_TOKENIZER_IN_TGI | True/False | False | Skip tokenizer for input/output processing | add -e in docker run command |
 | TGI_PROFILER_ENABLED | True/False | False | Collect high-level server tracing events | add -e in docker run command |
+| WARMUP_ENABLED | True/False | True | Enable warmup during server initialization to recompile all graphs. This can increase TGI setup time. | add -e in docker run command |
 </div>
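
For reference, a minimal standalone sketch (not part of the commit) of how the warmup path interprets these variables; the parsing and defaults mirror the client change further down, and the `main` wrapper is only for illustration.

```rust
use std::env;

fn main() {
    // WARMUP_ENABLED defaults to true; any value other than "true" disables warmup.
    let warmup_enabled: bool = env::var("WARMUP_ENABLED")
        .ok()
        .map_or(true, |value| value.to_lowercase() == "true");

    // Numeric knobs fall back to their documented defaults when unset.
    let read_env_var = |key: &str, default: u32| -> u32 {
        env::var(key)
            .ok()
            .map_or(default, |value| value.parse::<u32>().unwrap())
    };
    let seq_bucket_size = read_env_var("PAD_SEQUENCE_TO_MULTIPLE_OF", 128);
    let prefill_bucket_size = read_env_var("PREFILL_BATCH_BUCKET_SIZE", 1);

    println!(
        "warmup_enabled={warmup_enabled}, pad_to_multiple_of={seq_bucket_size}, prefill_bucket={prefill_bucket_size}"
    );
}
```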
@@ -213,7 +213,7 @@ message DecodeResponse {
 message WarmupRequest {
     /// Batch to warmup on
-    Batch batch = 1;
+    repeated Batch batches = 1;
 }

 /// Empty response
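
With prost's usual mapping, a singular message field generates an `Option<Batch>` while a `repeated` field generates a `Vec<Batch>`, so the client-side construction changes roughly as in this fragment (illustrative only; `WarmupRequest` and `Batch` come from the generated `pb::generate::v1` module, and `batch_a` / `batch_b` stand in for real batches):

```rust
// Before: one optional batch per warmup call.
let request = tonic::Request::new(WarmupRequest { batch: Some(batch) });

// After: a list of batches, so a single call can warm up prefill and
// concatenation together (the client below always sends two per shape).
let request = tonic::Request::new(WarmupRequest { batches: vec![batch_a, batch_b] });
```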
@@ -9,6 +9,7 @@ homepage.workspace = true
 futures = "^0.3"
 grpc-metadata = { path = "../grpc-metadata" }
 prost = "^0.12"
+rand = "0.8.5"
 thiserror = "^1.0"
 tokio = { version = "^1.32", features = ["sync"] }
 tonic = "^0.10"
@@ -2,8 +2,10 @@
 use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient;
 use crate::pb::generate::v1::*;
 use crate::Result;
+use std::env;
+use rand::{distributions::Uniform, Rng};
 use grpc_metadata::InjectTelemetryContext;
-use std::cmp::min;
+use std::cmp;
 use tonic::transport::{Channel, Uri};
 use tracing::instrument;
@@ -105,48 +107,115 @@ impl Client {
         max_prefill_tokens: u32,
         max_total_tokens: u32,
     ) -> Result<Option<u32>> {
-        let mut n_tokens = 0;
+        let warmup_enabled: bool = env::var("WARMUP_ENABLED").ok().map_or(true, |value| value.to_lowercase() == "true");
+        if !warmup_enabled {
+            return Ok(None);
+        }
+
+        let read_env_var = |key: &str, default: u32| -> u32 {
+            env::var(key).ok().map_or(default, |value| value.parse::<u32>().unwrap())
+        };
+
+        // get all possible prefill batch sizes
+        let max_prefill_batch_size: u32 = max_prefill_tokens / max_input_length;
+        let prefill_bucket_size: u32 = read_env_var("PREFILL_BATCH_BUCKET_SIZE", 1);
+        let batch_sizes: Vec<u32> = (1..max_prefill_batch_size+1).step_by(prefill_bucket_size as usize).collect();
+
+        // get all possible sequence lengths for prefill
+        let seq_bucket_size: u32 = read_env_var("PAD_SEQUENCE_TO_MULTIPLE_OF", 128);
+        let seq_lengths: Vec<u32> = (seq_bucket_size..max_input_length+1).step_by(seq_bucket_size as usize).collect();
+
+        // execute batch for each combination of batch size and sequence length
+        let mut shapes: Vec<(u32, u32)> = Vec::with_capacity(batch_sizes.len() * seq_lengths.len());
+        for batch_size in &batch_sizes {
+            for seq_length in &seq_lengths {
+                shapes.push((*batch_size, *seq_length));
+            }
+        }
+
+        let mut id_counter: u64 = 0;
+        for shape in shapes.iter() {
+            // create two batches in order to trigger concatenate operation
+            let batches: Vec<Batch> = vec![
+                self.create_warmup_batch(*shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size),
+                self.create_warmup_batch(*shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size)
+            ];
+            let request = tonic::Request::new(WarmupRequest { batches }).inject_context();
+            let _response = self.stub.warmup(request).await?.into_inner();
+        }
+
+        Ok(None) // No support for maximum total tokens
+    }
+
+    #[instrument(skip_all)]
+    fn create_warmup_batch(
+        &mut self,
+        shape: (u32, u32),
+        id_counter: &mut u64,
+        max_input_length: u32,
+        max_total_tokens: u32,
+        seq_bucket_size: u32,
+    ) -> Batch {
+        *id_counter += 1;
+        let (batch_size, input_length) = shape;
         let mut requests = Vec::new();
-        // Create requests
-        while n_tokens < max_prefill_tokens {
-            let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
+        for request_id in 0..batch_size {
             requests.push(Request {
-                id: 0,
-                // We truncate the input on the server side to be sure that it has the correct size
-                inputs: "_test ".to_string().repeat(max_input_length as usize),
-                truncate,
-                // Set sampling parameters to also take these ops into account in the max memory
+                id: *id_counter + request_id as u64,
+                inputs: self.get_random_input(input_length, seq_bucket_size),
+                truncate: max_input_length,
                 parameters: Some(NextTokenChooserParameters {
-                    temperature: 0.9,
-                    top_k: 10,
-                    top_p: 0.9,
-                    typical_p: 0.9,
+                    temperature: 1.0,
+                    top_k: 0,
+                    top_p: 1.0,
+                    typical_p: 1.0,
                     do_sample: false,
                     seed: 0,
-                    repetition_penalty: 1.2,
-                    watermark: true,
+                    repetition_penalty: 1.0,
+                    watermark: false,
                 }),
                 stopping_parameters: Some(StoppingCriteriaParameters {
-                    max_new_tokens: max_total_tokens - truncate,
+                    max_new_tokens: 10,
                     stop_sequences: vec![],
                     ignore_eos_token: true,
                 }),
-                prefill_logprobs: true,
-                top_n_tokens: 20,
+                prefill_logprobs: false,
+                top_n_tokens: 0,
             });
-            n_tokens += max_input_length;
         }

-        let batch = Batch {
-            id: 0,
+        Batch {
+            id: *id_counter,
             size: requests.len() as u32,
             requests,
-            max_tokens: 0,
-        };
+            max_tokens: max_total_tokens,
+        }
+    }

-        let request = tonic::Request::new(WarmupRequest { batch: Some(batch) }).inject_context();
-        let response = self.stub.warmup(request).await?.into_inner();
-        Ok(response.max_supported_total_tokens)
+    #[instrument(skip_all)]
+    fn get_random_input(
+        &mut self,
+        input_length: u32,
+        seq_bucket_size: u32,
+    ) -> String {
+        let skip_tokenizer_in_tgi: bool = env::var("SKIP_TOKENIZER_IN_TGI")
+            .ok()
+            .map_or(false, |value| value.to_lowercase() == "true");
+        if skip_tokenizer_in_tgi {
+            // generate random tokens
+            let mut rng = rand::thread_rng();
+            let range = Uniform::new(2, 8192);
+            let tokens = input_length - seq_bucket_size / 2;
+            (0..tokens)
+                .map(|_| rng.sample(&range).to_string())
+                .collect::<Vec<String>>()
+                .join(", ")
+        } else {
+            // repeat test string to get expected input shape
+            let bucket_id = input_length / seq_bucket_size;
+            let repeats = cmp::max(1, (bucket_id - 1) * seq_bucket_size / 2);
+            "_test ".to_string().repeat(repeats as usize)
+        }
     }

     /// Generate one token for each request in the given batch
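
To make the bucketing above concrete, here is a small self-contained sketch (not part of the commit) that enumerates the warmup shapes for illustrative values of `max_prefill_tokens` and `max_input_length`; the real client derives these from the launcher arguments and the environment variables listed earlier.

```rust
fn main() {
    // Illustrative values only.
    let max_prefill_tokens: u32 = 4096;
    let max_input_length: u32 = 1024;
    let prefill_bucket_size: u32 = 1; // default PREFILL_BATCH_BUCKET_SIZE
    let seq_bucket_size: u32 = 128;   // default PAD_SEQUENCE_TO_MULTIPLE_OF

    // Batch sizes 1..=4 and sequence lengths 128, 256, ..., 1024.
    let max_prefill_batch_size = max_prefill_tokens / max_input_length;
    let batch_sizes: Vec<u32> = (1..max_prefill_batch_size + 1)
        .step_by(prefill_bucket_size as usize)
        .collect();
    let seq_lengths: Vec<u32> = (seq_bucket_size..max_input_length + 1)
        .step_by(seq_bucket_size as usize)
        .collect();

    // Each (batch_size, seq_length) pair is one warmup shape; the client sends
    // two identical batches per shape so the server also compiles the
    // concatenate path.
    for batch_size in &batch_sizes {
        for seq_length in &seq_lengths {
            println!("warmup shape: batch_size={batch_size}, seq_length={seq_length}");
        }
    }
}
```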
@@ -990,3 +990,25 @@ class CausalLM(Model):
             else:
                 self.hb_profiler.step()
         return generations, batch if not stopped else None
+
+    def warmup(self, batches: List[CausalLMBatch]) -> None:
+        self.shifting_warmup()
+
+        if len(batches) < 2:
+            return
+
+        # prefill
+        _, prefill_batch = self.generate_token([batches[0]])
+        # decode
+        _, decode_batch = self.generate_token([prefill_batch])
+        # prefill
+        _, prefill_batch = self.generate_token([batches[1]])
+        # concatenate and decode
+        _, decode_batch = self.generate_token([decode_batch, prefill_batch])
+        # decodes
+        while decode_batch is not None:
+            _, decode_batch = self.generate_token([decode_batch])
+
+    def shifting_warmup(self) -> None:
+        # TODO: add warmup for all possible shift variants
+        pass
@@ -67,16 +67,15 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

     async def Warmup(self, request, context):
-        with self.profiler.record_event("external", "warmup"):
-            # batch = self.model.batch_type.from_pb(
-            #     request.batch, self.model.tokenizer, self.model.dtype, self.model.device
-            # )
-            # max_supported_total_tokens = self.model.warmup(batch)
-
-            # return generate_pb2.WarmupResponse(
-            #     max_supported_total_tokens=max_supported_total_tokens
-            # )
-            logger.warning("Warmup is not enabled on HPU.")
+        def batch_from_pb(batch):
+            return self.model.batch_type.from_pb(
+                batch, self.model.tokenizer, self.model.dtype, self.model.device, self.model.is_optimized_for_gaudi
+            )
+
+        with self.profiler.record_event("external", "warmup"):
+            batches = [batch_from_pb(batch) for batch in request.batches]
+            self.model.warmup(batches)
+
         return generate_pb2.WarmupResponse()

     async def Prefill(self, request, context):