re-create slots

OlivierDehaene 2024-10-02 14:17:26 +02:00
parent 4db5e7dde6
commit b49978ff67
6 changed files with 57 additions and 27 deletions
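
Read together, the six diffs below make one change: DecodeRequest gains an `optional Batch batch` field, the Rust gRPC clients grow a matching Option<Batch> argument (one call site simply passes batch: None), the router's batching_task stores a freshly queued batch in additional_batch when the model supports chunking and passes it along to decode, and the Python server deserializes that extra batch and concatenates it with the cached batches, which is presumably where the slots from the commit title get re-created. Below is a minimal, self-contained sketch of that router-side decision; the types are stubs and the decode call is a synchronous stand-in, not the real async client code.

// Stubs standing in for the prost-generated protobuf types.
#[derive(Clone, Debug)]
struct Batch;
#[derive(Clone, Debug)]
struct CachedBatch;

// Stand-in for the updated decode call: the optional raw batch now travels
// alongside the cached batches.
fn decode(batch: Option<Batch>, batches: Vec<CachedBatch>) {
    println!(
        "decode: {} cached batch(es), extra batch attached: {}",
        batches.len(),
        batch.is_some()
    );
}

fn main() {
    let support_chunking = true;
    let new_batch = Batch;           // freshly pulled from the queue
    let batches = vec![CachedBatch]; // batches already decoding

    // Mirrors the batching_task change: only the chunking path attaches the
    // new batch to the decode call.
    let additional_batch = if support_chunking { Some(new_batch) } else { None };
    decode(additional_batch, batches);
}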


@@ -237,7 +237,11 @@ impl Client {
&mut self,
batches: Vec<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
let request = tonic::Request::new(DecodeRequest {
batch: None,
batches,
})
.inject_context();
let response = self.stub.decode(request).await?.into_inner();
Ok((
response.generations,


@@ -200,6 +200,8 @@ pub(crate) async fn batching_task(
(min_size, max_size, max_batch_prefill_tokens)
};
let mut additional_batch = None;
// Try to get a new batch
if let Some((mut new_entries, new_batch, span)) = queue
.next_batch(min_size, max_size, prefill_token_budget, token_budget)
@@ -210,31 +212,40 @@ pub(crate) async fn batching_task(
metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
.increment(1);
} else {
metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
.increment(1);
let counter = if support_chunking {
metrics::counter!("tgi_batch_concat", "reason" => "chunking")
} else {
metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
};
counter.increment(1);
}
entries.iter_mut().for_each(|(_, entry)| {
// Create a new span to add the info that this entry is waiting
// because a new batch is being computed
let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
// Add relationships
span.follows_from(&entry_waiting_span);
entry_waiting_span.follows_from(&span);
// Update entry
entry.temp_span = Some(entry_waiting_span);
});
// Generate one token for this new batch to have the attention past in cache
let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
.instrument(span)
.await;
// Reset waiting counter
waiting_tokens = 1;
// Extend current batch with the new batch
if let Some(new_cached_batch) = new_cached_batch {
if support_chunking {
entries.extend(new_entries);
batches.push(new_cached_batch);
additional_batch = Some(new_batch);
} else {
entries.iter_mut().for_each(|(_, entry)| {
// Create a new span to add the info that this entry is waiting
// because a new batch is being computed
let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
// Add relationships
span.follows_from(&entry_waiting_span);
entry_waiting_span.follows_from(&span);
// Update entry
entry.temp_span = Some(entry_waiting_span);
});
// Generate one token for this new batch to have the attention past in cache
let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
.instrument(span)
.await;
// Reset waiting counter
waiting_tokens = 1;
// Extend current batch with the new batch
if let Some(new_cached_batch) = new_cached_batch {
entries.extend(new_entries);
batches.push(new_cached_batch);
}
}
}
@@ -252,7 +263,7 @@ pub(crate) async fn batching_task(
entry.temp_span = Some(entry_batch_span);
});
cached_batch = decode(&mut client, batches, &mut entries)
cached_batch = decode(&mut client, additional_batch, batches, &mut entries)
.instrument(next_batch_span)
.await;
waiting_tokens += 1;
@@ -306,6 +317,7 @@ async fn prefill(
#[instrument(skip_all)]
async fn decode(
client: &mut ShardedClient,
batch: Option<Batch>,
batches: Vec<CachedBatch>,
entries: &mut IntMap<u64, Entry>,
) -> Option<CachedBatch> {
@@ -313,7 +325,7 @@ async fn decode(
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);
match client.decode(batches).await {
match client.decode(batch, batches).await {
Ok((generations, next_batch, timings)) => {
let start_filtering_time = Instant::now();
// Send generated tokens and filter stopped entries
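
One side effect of the hunk above: the tgi_batch_concat counter now records why a concatenation happened, adding a "chunking" reason next to the existing "backpressure" and "wait_exceeded" ones, with the counter handle chosen first and incremented once. A small reproduction of that pattern, assuming the metrics crate with the 0.22+ macro API that the hunk itself uses; no recorder is installed here, so the increments are no-ops:

// Assumes `metrics` in Cargo.toml (0.22+ style macros, as in the hunk above).
fn record_concat_reason(backpressure: bool, support_chunking: bool) {
    let counter = if backpressure {
        metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
    } else if support_chunking {
        // New with this commit: concatenation driven by the chunking path.
        metrics::counter!("tgi_batch_concat", "reason" => "chunking")
    } else {
        metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
    };
    counter.increment(1);
}

fn main() {
    record_concat_reason(false, true);  // chunking
    record_concat_reason(false, false); // wait_exceeded
}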


@@ -235,9 +235,10 @@ impl Client {
#[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
pub async fn decode(
&mut self,
batch: Option<Batch>,
batches: Vec<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
let request = tonic::Request::new(DecodeRequest { batches, batch }).inject_context();
let response = self.stub.decode(request).await?.into_inner();
Ok((
response.generations,


@@ -167,12 +167,13 @@ impl ShardedClient {
#[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))]
pub async fn decode(
&mut self,
batch: Option<Batch>,
batches: Vec<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.decode(batches.clone())))
.map(|client| Box::pin(client.decode(batch.clone(), batches.clone())))
.collect();
#[allow(clippy::type_complexity)]
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
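
Because every shard must receive the same extra batch, ShardedClient::decode clones the Option<Batch> into each per-shard call, exactly as it already does for the cached batches. A toy illustration of that fan-out, with a stub Batch type and a plain loop standing in for the pinned futures above:

// Stub in place of the prost-generated Batch; only the clone-per-shard pattern matters here.
#[derive(Clone, Debug)]
struct Batch;

fn decode_on_shard(shard: usize, batch: Option<Batch>) {
    println!("shard {shard} received an extra batch: {}", batch.is_some());
}

fn main() {
    let batch = Some(Batch);
    // Matches `client.decode(batch.clone(), batches.clone())`: each shard gets its own copy.
    for shard in 0..2 {
        decode_on_shard(shard, batch.clone());
    }
}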


@@ -243,6 +243,8 @@ message PrefillResponse {
message DecodeRequest {
/// Cached batches
repeated CachedBatch batches = 1;
/// Optional Batch
optional Batch batch = 2;
}
message DecodeResponse {
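
For reference, prost maps a singular message field such as `optional Batch batch = 2` to an Option in the generated Rust struct, which is why the clients above can write DecodeRequest { batches, batch } and lets the first hunk's call site pass batch: None. Roughly the generated shape, with prost attributes omitted and the element types stubbed out:

// Approximate shape of the generated request after this change; the real struct
// carries prost derive and field attributes and lives in the generated client module.
pub struct Batch;       // stub
pub struct CachedBatch; // stub

pub struct DecodeRequest {
    /// Cached batches
    pub batches: Vec<CachedBatch>,
    /// Optional batch
    pub batch: Option<Batch>,
}

fn main() {
    // No extra batch attached (as in the first hunk of this commit)...
    let _plain = DecodeRequest { batches: vec![CachedBatch], batch: None };
    // ...versus the chunking path, where the freshly queued batch rides along.
    let _chunked = DecodeRequest { batches: vec![CachedBatch], batch: Some(Batch) };
}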


@@ -179,6 +179,16 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
if len(batches) == 0:
raise ValueError("All batches are empty")
if self.model.support_chunking:
if request.HasField("batch"):
batch = self.model.batch_type.from_pb(
request.batch,
self.model.tokenizer,
self.model.dtype,
self.model.device,
)
batches.append(batch)
if len(batches) > 1:
start_concat = time.time_ns()
batch = self.model.batch_type.concatenate(batches)