From b49978ff67cb171223562efa98f9a717aa215859 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Wed, 2 Oct 2024 14:17:26 +0200
Subject: [PATCH] re-create slots

---
 backends/client/src/v3/client.rs         |  6 ++-
 backends/v3/src/backend.rs               | 60 ++++++++++++++----------
 backends/v3/src/client/grpc_client.rs    |  3 +-
 backends/v3/src/client/sharded_client.rs |  3 +-
 proto/v3/generate.proto                  |  2 +
 server/text_generation_server/server.py  | 10 ++++
 6 files changed, 57 insertions(+), 27 deletions(-)

diff --git a/backends/client/src/v3/client.rs b/backends/client/src/v3/client.rs
index 61d1ea1b..5191f8dd 100644
--- a/backends/client/src/v3/client.rs
+++ b/backends/client/src/v3/client.rs
@@ -237,7 +237,11 @@ impl Client {
         &mut self,
         batches: Vec<CachedBatch>,
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
-        let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
+        let request = tonic::Request::new(DecodeRequest {
+            batch: None,
+            batches,
+        })
+        .inject_context();
         let response = self.stub.decode(request).await?.into_inner();
         Ok((
             response.generations,
diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs
index bfe7932f..183f4e52 100644
--- a/backends/v3/src/backend.rs
+++ b/backends/v3/src/backend.rs
@@ -200,6 +200,8 @@ pub(crate) async fn batching_task(
                 (min_size, max_size, max_batch_prefill_tokens)
             };
 
+            let mut additional_batch = None;
+
             // Try to get a new batch
             if let Some((mut new_entries, new_batch, span)) = queue
                 .next_batch(min_size, max_size, prefill_token_budget, token_budget)
@@ -210,31 +212,40 @@ pub(crate) async fn batching_task(
                     metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
                         .increment(1);
                 } else {
-                    metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
-                        .increment(1);
+                    let counter = if support_chunking {
+                        metrics::counter!("tgi_batch_concat", "reason" => "chunking")
+                    } else {
+                        metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
+                    };
+                    counter.increment(1);
                 }
 
-                entries.iter_mut().for_each(|(_, entry)| {
-                    // Create a new span to add the info that this entry is waiting
-                    // because a new batch is being computed
-                    let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
-                    // Add relationships
-                    span.follows_from(&entry_waiting_span);
-                    entry_waiting_span.follows_from(&span);
-                    // Update entry
-                    entry.temp_span = Some(entry_waiting_span);
-                });
-
-                // Generate one token for this new batch to have the attention past in cache
-                let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
-                    .instrument(span)
-                    .await;
-                // Reset waiting counter
-                waiting_tokens = 1;
-                // Extend current batch with the new batch
-                if let Some(new_cached_batch) = new_cached_batch {
+                if support_chunking {
                     entries.extend(new_entries);
-                    batches.push(new_cached_batch);
+                    additional_batch = Some(new_batch);
+                } else {
+                    entries.iter_mut().for_each(|(_, entry)| {
+                        // Create a new span to add the info that this entry is waiting
+                        // because a new batch is being computed
+                        let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
+                        // Add relationships
+                        span.follows_from(&entry_waiting_span);
+                        entry_waiting_span.follows_from(&span);
+                        // Update entry
+                        entry.temp_span = Some(entry_waiting_span);
+                    });
+
+                    // Generate one token for this new batch to have the attention past in cache
+                    let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
+                        .instrument(span)
+                        .await;
+                    // Reset waiting counter
+                    waiting_tokens = 1;
+                    // Extend current batch with the new batch
+                    if let Some(new_cached_batch) = new_cached_batch {
+                        entries.extend(new_entries);
+                        batches.push(new_cached_batch);
+                    }
                 }
             }
 
@@ -252,7 +263,7 @@ pub(crate) async fn batching_task(
                     entry.temp_span = Some(entry_batch_span);
                 });
 
-                cached_batch = decode(&mut client, batches, &mut entries)
+                cached_batch = decode(&mut client, additional_batch, batches, &mut entries)
                     .instrument(next_batch_span)
                     .await;
                 waiting_tokens += 1;
@@ -306,6 +317,7 @@ async fn prefill(
 #[instrument(skip_all)]
 async fn decode(
     client: &mut ShardedClient,
+    batch: Option<Batch>,
     batches: Vec<CachedBatch>,
     entries: &mut IntMap<u64, Entry>,
 ) -> Option<CachedBatch> {
@@ -313,7 +325,7 @@ async fn decode(
     let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
     metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);
 
-    match client.decode(batches).await {
+    match client.decode(batch, batches).await {
         Ok((generations, next_batch, timings)) => {
             let start_filtering_time = Instant::now();
             // Send generated tokens and filter stopped entries
diff --git a/backends/v3/src/client/grpc_client.rs b/backends/v3/src/client/grpc_client.rs
index 3b4432a7..ab93db9b 100644
--- a/backends/v3/src/client/grpc_client.rs
+++ b/backends/v3/src/client/grpc_client.rs
@@ -235,9 +235,10 @@ impl Client {
     #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
     pub async fn decode(
         &mut self,
+        batch: Option<Batch>,
         batches: Vec<CachedBatch>,
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
-        let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
+        let request = tonic::Request::new(DecodeRequest { batches, batch }).inject_context();
         let response = self.stub.decode(request).await?.into_inner();
         Ok((
             response.generations,
diff --git a/backends/v3/src/client/sharded_client.rs b/backends/v3/src/client/sharded_client.rs
index 97a1eab6..8af4b26f 100644
--- a/backends/v3/src/client/sharded_client.rs
+++ b/backends/v3/src/client/sharded_client.rs
@@ -167,12 +167,13 @@ impl ShardedClient {
     #[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))]
     pub async fn decode(
         &mut self,
+        batch: Option<Batch>,
         batches: Vec<CachedBatch>,
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
-            .map(|client| Box::pin(client.decode(batches.clone())))
+            .map(|client| Box::pin(client.decode(batch.clone(), batches.clone())))
             .collect();
         #[allow(clippy::type_complexity)]
         let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto
index cfb92ba8..15a93ac9 100644
--- a/proto/v3/generate.proto
+++ b/proto/v3/generate.proto
@@ -243,6 +243,8 @@ message PrefillResponse {
 message DecodeRequest {
   /// Cached batches
   repeated CachedBatch batches = 1;
+  /// Optional Batch
+  optional Batch batch = 2;
 }
 
 message DecodeResponse {
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index bd4b3a53..d89df966 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -179,6 +179,16 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         if len(batches) == 0:
            raise ValueError("All batches are empty")
 
+        if self.model.support_chunking:
+            if request.HasField("batch"):
+                batch = self.model.batch_type.from_pb(
+                    request.batch,
+                    self.model.tokenizer,
+                    self.model.dtype,
+                    self.model.device,
+                )
+                batches.append(batch)
+
         if len(batches) > 1:
             start_concat = time.time_ns()
             batch = self.model.batch_type.concatenate(batches)
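
Hypothetical usage sketch (not part of the patch): with the new optional "batch" field on DecodeRequest, a caller built against the regenerated v3 Python stubs could attach a freshly tokenized Batch to a Decode call; the handler above then converts it with batch_type.from_pb and appends it before concatenation. The helper name, channel target, and standalone-client framing below are illustrative assumptions, not code from this change.

import grpc

from text_generation_server.pb import generate_pb2, generate_pb2_grpc


def decode_with_extra_batch(target, cached_batches, new_batch=None):
    # Send a Decode RPC carrying existing CachedBatch messages and, optionally,
    # a new Batch piggybacked on the optional field introduced by this patch.
    with grpc.insecure_channel(target) as channel:
        stub = generate_pb2_grpc.TextGenerationServiceStub(channel)
        request = generate_pb2.DecodeRequest(batches=cached_batches)
        if new_batch is not None:
            # Because the field is declared optional, the server can gate on
            # request.HasField("batch") exactly as server.py does above.
            request.batch.CopyFrom(new_batch)
        return stub.Decode(request)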