diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs
index 4a119e86..d40a2e8d 100644
--- a/benchmark/src/generation.rs
+++ b/benchmark/src/generation.rs
@@ -1,6 +1,6 @@
 use std::time::{Duration, Instant};
 use text_generation_client::{
-    Batch, ClientError, NextTokenChooserParameters, Request, ShardedClient,
+    Batch, CachedBatch, ClientError, NextTokenChooserParameters, Request, ShardedClient,
     StoppingCriteriaParameters,
 };
 use tokenizers::{Tokenizer, TruncationDirection};
@@ -126,7 +126,7 @@ async fn prefill(
     batch_size: u32,
     decode_length: u32,
     client: &mut ShardedClient,
-) -> Result<(Prefill, Batch), ClientError> {
+) -> Result<(Prefill, CachedBatch), ClientError> {
     // Create requests
     let requests = (0..batch_size)
         .map(|id| Request {
@@ -180,7 +180,7 @@ async fn prefill(
 }
 
 /// Run a full decode
-async fn decode(batch: Batch, client: &mut ShardedClient) -> Result<Decode, ClientError> {
+async fn decode(batch: CachedBatch, client: &mut ShardedClient) -> Result<Decode, ClientError> {
     let mut decode_length = 0;
     let batch_size = batch.size;
 
diff --git a/proto/generate.proto b/proto/generate.proto
index 894d7bc1..0c40e5bb 100644
--- a/proto/generate.proto
+++ b/proto/generate.proto
@@ -100,6 +100,17 @@ message Batch {
     uint32 max_tokens = 4;
 }
 
+message CachedBatch {
+    /// Batch ID
+    uint64 id = 1;
+    /// Individual requests ids
+    repeated uint64 request_ids = 2;
+    /// Batch size (==len(requests))
+    uint32 size = 3;
+    /// Maximum number of tokens this batch will grow to
+    uint32 max_tokens = 4;
+}
+
 enum FinishReason {
     FINISH_REASON_LENGTH = 0;
     FINISH_REASON_EOS_TOKEN = 1;
@@ -140,19 +151,19 @@ message Generation {
     /// Is it a special token
     bool token_is_special = 6;
     /// Complete generated text
-    GeneratedText generated_text = 7;
+    optional GeneratedText generated_text = 7;
 }
 
 message FilterBatchRequest {
     /// Batch ID
     uint64 batch_id = 1;
     /// Requests to keep
-    repeated Request keep_requests = 2;
+    repeated uint64 request_ids = 2;
 }
 
 message FilterBatchResponse {
     /// Filtered Batch (cached)
-    Batch batch = 1;
+    CachedBatch batch = 1;
 }
 
 
@@ -165,17 +176,17 @@ message PrefillResponse {
     /// Generation
     repeated Generation generations = 1;
     /// Next batch (cached)
-    optional Batch batch = 2;
+    optional CachedBatch batch = 2;
 }
 
 message DecodeRequest {
     /// Cached batches
-    repeated Batch batches = 1;
+    repeated CachedBatch batches = 1;
 }
 
 message DecodeResponse {
     /// Decodes
     repeated Generation generations = 1;
     /// Next batch (cached)
-    optional Batch batch = 2;
+    optional CachedBatch batch = 2;
 }
diff --git a/router/client/src/client.rs b/router/client/src/client.rs
index bf1b6b58..81f023ef 100644
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -83,11 +83,11 @@ impl Client {
     pub async fn filter_batch(
         &mut self,
         batch_id: u64,
-        keep_requests: Vec<Request>,
-    ) -> Result<Option<Batch>> {
+        request_ids: Vec<u64>,
+    ) -> Result<Option<CachedBatch>> {
         let request = tonic::Request::new(FilterBatchRequest {
             batch_id,
-            keep_requests,
+            request_ids,
         })
         .inject_context();
         let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
@@ -99,7 +99,10 @@ impl Client {
     /// Returns Generation for each request in batch
     /// and the next cached batch
     #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
-    pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
+    pub async fn prefill(
+        &mut self,
+        batch: Batch,
+    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
         let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
         let response = self.stub.prefill(request).await?.into_inner();
         Ok((response.generations, response.batch))
@@ -112,8 +115,8 @@ impl Client {
     #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
     pub async fn decode(
         &mut self,
-        batches: Vec<Batch>,
-    ) -> Result<(Vec<Generation>, Option<Batch>)> {
+        batches: Vec<CachedBatch>,
+    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
         let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
         let response = self.stub.decode(request).await?.into_inner();
         Ok((response.generations, response.batch))
diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs
index 66f2055a..f334be21 100644
--- a/router/client/src/lib.rs
+++ b/router/client/src/lib.rs
@@ -9,8 +9,8 @@ pub use client::Client;
 pub use pb::generate::v1::HealthResponse;
 pub use pb::generate::v1::InfoResponse as ShardInfo;
 pub use pb::generate::v1::{
-    Batch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens,
-    Request, StoppingCriteriaParameters,
+    Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters,
+    PrefillTokens, Request, StoppingCriteriaParameters,
 };
 pub use sharded_client::ShardedClient;
 use thiserror::Error;
diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs
index 60b81fe6..b81eed46 100644
--- a/router/client/src/sharded_client.rs
+++ b/router/client/src/sharded_client.rs
@@ -1,5 +1,5 @@
 /// Multi shard Client
-use crate::{Batch, Client, Generation, HealthResponse, Request, ShardInfo};
+use crate::{Batch, CachedBatch, Client, Generation, HealthResponse, ShardInfo};
 use crate::{ClientError, Result};
 use futures::future::join_all;
 use tonic::transport::Uri;
@@ -76,12 +76,12 @@ impl ShardedClient {
     pub async fn filter_batch(
         &mut self,
         batch_id: u64,
-        keep_requests: Vec<Request>,
-    ) -> Result<Option<Batch>> {
+        request_ids: Vec<u64>,
+    ) -> Result<Option<CachedBatch>> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
-            .map(|client| Box::pin(client.filter_batch(batch_id, keep_requests.clone())))
+            .map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))
             .collect();
         // all shards return the same message
         join_all(futures).await.pop().unwrap()
@@ -92,13 +92,16 @@ impl ShardedClient {
     /// Returns Generation for each request in batch
     /// and the next cached batch
     #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
-    pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
+    pub async fn prefill(
+        &mut self,
+        batch: Batch,
+    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
             .map(|client| Box::pin(client.prefill(batch.clone())))
             .collect();
-        let results: Result<Vec<(Vec<Generation>, Option<Batch>)>> =
+        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>)>> =
             join_all(futures).await.into_iter().collect();
         merge_generations(results?)
     }
@@ -110,14 +113,14 @@ impl ShardedClient {
     #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
     pub async fn decode(
         &mut self,
-        batches: Vec<Batch>,
-    ) -> Result<(Vec<Generation>, Option<Batch>)> {
+        batches: Vec<CachedBatch>,
+    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
             .map(|client| Box::pin(client.decode(batches.clone())))
             .collect();
-        let results: Result<Vec<(Vec<Generation>, Option<Batch>)>> =
+        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>)>> =
             join_all(futures).await.into_iter().collect();
         merge_generations(results?)
     }
@@ -125,8 +128,8 @@ impl ShardedClient {
 
 /// Merge generations from the different model shards
 fn merge_generations(
-    mut results: Vec<(Vec<Generation>, Option<Batch>)>,
-) -> Result<(Vec<Generation>, Option<Batch>)> {
+    mut results: Vec<(Vec<Generation>, Option<CachedBatch>)>,
+) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
     let (mut generations, next_batch) = results.pop().ok_or(ClientError::EmptyResults)?;
 
     for (mut shard_generations, _) in results.into_iter() {
diff --git a/router/src/infer.rs b/router/src/infer.rs
index 313ec3e1..00fa2818 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -12,7 +12,7 @@ use std::sync::{
     Arc,
 };
 use text_generation_client::{
-    Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
+    Batch, CachedBatch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
 };
 use thiserror::Error;
 use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
@@ -352,7 +352,7 @@ async fn prefill(
     batch: Batch,
     entries: &mut IntMap<u64, Entry>,
     generation_health: &Arc<AtomicBool>,
-) -> Option<Batch> {
+) -> Option<CachedBatch> {
     let start_time = Instant::now();
     let batch_id = batch.id;
     metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");
@@ -386,10 +386,10 @@
 #[instrument(skip_all)]
 async fn decode(
     client: &mut ShardedClient,
-    batches: Vec<Batch>,
+    batches: Vec<CachedBatch>,
     entries: &mut IntMap<u64, Entry>,
     generation_health: &Arc<AtomicBool>,
-) -> Option<Batch> {
+) -> Option<CachedBatch> {
     let start_time = Instant::now();
     let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
     metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
@@ -425,9 +425,9 @@
 #[instrument(skip_all)]
 async fn filter_batch(
     client: &mut ShardedClient,
-    next_batch: Option<Batch>,
+    next_batch: Option<CachedBatch>,
     entries: &IntMap<u64, Entry>,
-) -> Option<Batch> {
+) -> Option<CachedBatch> {
     let mut batch = next_batch?;
 
     // No need to filter
@@ -438,9 +438,9 @@
     let id = batch.id;
 
     // Retain only requests that are still in entries
-    batch.requests.retain(|r| entries.contains_key(&r.id));
+    batch.request_ids.retain(|id| entries.contains_key(id));
 
-    if batch.requests.is_empty() {
+    if batch.request_ids.is_empty() {
         // All requests have been filtered out
         // Next batch is now empty
         // Clear it from the Python shards cache
@@ -450,7 +450,7 @@
     } else {
         // Filter Python shard cache
         // We unwrap here as we need to panic since we cannot recover if this method fails
-        client.filter_batch(id, batch.requests).await.unwrap()
+        client.filter_batch(id, batch.request_ids).await.unwrap()
     }
 }
 
diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py
index f0adab97..105b3573 100644
--- a/server/tests/models/test_bloom.py
+++ b/server/tests/models/test_bloom.py
@@ -178,7 +178,7 @@ def test_causal_lm_generate_token_completion_multi(
     # Copy stopping_criterias before filtering
     stopping_criterias = default_multi_requests_bloom_batch.stopping_criterias.copy()
 
-    next_batch = next_batch.filter([next_batch.requests[0]])
+    next_batch = next_batch.filter([next_batch.requests[0].id])
 
     for _ in range(
         stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
@@ -286,7 +286,7 @@ def test_batch_concatenate(
         == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
     )
 
-    next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]])
+    next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
 
     for _ in range(
         default_bloom_batch.stopping_criterias[0].max_new_tokens
@@ -309,7 +309,7 @@ def test_batch_concatenate(
         == default_bloom_batch.stopping_criterias[0].max_new_tokens
     )
 
-    next_batch = next_batch.filter([next_batch.requests[1]])
+    next_batch = next_batch.filter([next_batch.requests[1].id])
 
     for _ in range(
         default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py
index f1f13e4b..d8d1bd16 100644
--- a/server/tests/models/test_causal_lm.py
+++ b/server/tests/models/test_causal_lm.py
@@ -178,7 +178,7 @@ def test_causal_lm_generate_token_completion_multi(
         default_multi_requests_causal_lm_batch.stopping_criterias.copy()
     )
 
-    next_batch = next_batch.filter([next_batch.requests[0]])
+    next_batch = next_batch.filter([next_batch.requests[0].id])
 
     for _ in range(
         stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
@@ -285,7 +285,7 @@ def test_batch_concatenate(
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )
 
-    next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]])
+    next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
 
     for _ in range(
         default_causal_lm_batch.stopping_criterias[0].max_new_tokens
@@ -306,7 +306,7 @@ def test_batch_concatenate(
         == default_causal_lm_batch.stopping_criterias[0].max_new_tokens
     )
 
-    next_batch = next_batch.filter([next_batch.requests[1]])
+    next_batch = next_batch.filter([next_batch.requests[1].id])
 
     for _ in range(
         default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py
index ba769e75..8fdeee60 100644
--- a/server/tests/models/test_seq2seq_lm.py
+++ b/server/tests/models/test_seq2seq_lm.py
@@ -190,7 +190,7 @@ def test_seq2seq_lm_generate_token_completion_multi(
     )
     assert generations[1].generated_text.generated_tokens == 5
 
-    next_batch = next_batch.filter([next_batch.requests[0]])
+    next_batch = next_batch.filter([next_batch.requests[0].id])
 
     generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert len(generations) == len(next_batch)
@@ -323,7 +323,7 @@ def test_batch_concatenate(
     )
     assert generations[2].generated_text.generated_tokens == 5
 
-    next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]])
+    next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
 
     generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None
@@ -333,7 +333,7 @@
     assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
     assert generations[0].generated_text.generated_tokens == 7
 
-    next_batch = next_batch.filter([next_batch.requests[1]])
+    next_batch = next_batch.filter([next_batch.requests[1].id])
 
     generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is None
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 09df70d2..81a5e75e 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -53,10 +53,10 @@ class CausalLMBatch(Batch):
     # Past metadata
     keys_head_dim_last: bool = True
 
-    def to_pb(self) -> generate_pb2.Batch:
-        return generate_pb2.Batch(
+    def to_pb(self) -> generate_pb2.CachedBatch:
+        return generate_pb2.CachedBatch(
             id=self.batch_id,
-            requests=self.requests,
+            request_ids=[r.id for r in self.requests],
             size=len(self),
             max_tokens=self.max_tokens,
         )
@@ -143,16 +143,17 @@ class CausalLMBatch(Batch):
         )
 
     @tracer.start_as_current_span("filter")
-    def filter(self, requests: List[generate_pb2.Request]) -> Optional["CausalLMBatch"]:
-        if len(requests) == 0:
+    def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:
+        if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
-        if len(requests) == len(self):
+        if len(request_ids) == len(self):
             return self
 
         keep_indices = []
 
         # New values after filtering
         requests_idx_mapping = {}
+        requests = []
         input_lengths = []
         prefix_offsets = []
         read_offsets = []
@@ -165,11 +166,12 @@ class CausalLMBatch(Batch):
         total_remaining_decode_tokens = 0
         new_padding_right_offset = 0
 
-        for i, r in enumerate(requests):
-            idx = self.requests_idx_mapping[r.id]
-            requests_idx_mapping[r.id] = i
+        for i, request_id in enumerate(request_ids):
+            idx = self.requests_idx_mapping[request_id]
+            requests_idx_mapping[request_id] = i
             keep_indices.append(idx)
 
+            requests.append(self.requests[idx])
             prefix_offsets.append(self.prefix_offsets[idx])
             read_offsets.append(self.read_offsets[idx])
             all_input_ids.append(self.all_input_ids[idx])
@@ -220,7 +222,7 @@ class CausalLMBatch(Batch):
             layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
             del past_values
 
-        max_tokens = len(requests) * max_input_length + total_remaining_decode_tokens
+        max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens
 
         self.requests = requests
         self.requests_idx_mapping = requests_idx_mapping
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index d6e73ad8..baa6cd7f 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -62,10 +62,10 @@ class FlashCausalLMBatch(Batch):
     # Maximum number of tokens this batch will grow to
     max_tokens: int
 
-    def to_pb(self) -> generate_pb2.Batch:
-        return generate_pb2.Batch(
+    def to_pb(self) -> generate_pb2.CachedBatch:
+        return generate_pb2.CachedBatch(
             id=self.batch_id,
-            requests=self.requests,
+            request_ids=[r.id for r in self.requests],
             size=len(self),
             max_tokens=self.max_tokens,
         )
@@ -161,14 +161,14 @@ class FlashCausalLMBatch(Batch):
         )
 
     @tracer.start_as_current_span("filter")
-    def filter(self, requests: List[generate_pb2.Request]) -> "FlashCausalLMBatch":
-        if len(requests) == 0:
+    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
+        if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
         # We assume that if len(requests) == len(self) then the requests are the same
-        if len(requests) == len(self):
+        if len(request_ids) == len(self):
             return self
 
-        single_request = len(requests) == 1
+        single_request = len(request_ids) == 1
 
         # Cumulative length
         cumulative_length = 0
@@ -176,16 +176,17 @@ class FlashCausalLMBatch(Batch):
         # New values after filtering
         requests_idx_mapping = {}
 
-        input_ids = self.input_ids.new_empty(len(requests))
-        position_ids = self.position_ids.new_empty(len(requests))
+        input_ids = self.input_ids.new_empty(len(request_ids))
+        position_ids = self.position_ids.new_empty(len(request_ids))
         # Create on CPU to only move to GPU once instead of at every copy
-        cu_seqlens = torch.zeros(len(requests) + 1, dtype=torch.int32)
+        cu_seqlens = torch.zeros(len(request_ids) + 1, dtype=torch.int32)
         cu_seqlens_q = torch.arange(
-            0, len(requests) + 1, device=self.cu_seqlens_q.device, dtype=torch.int32
+            0, len(request_ids) + 1, device=self.cu_seqlens_q.device, dtype=torch.int32
         )
         max_seqlen = 0
 
         past_key_values = []
 
+        requests = []
         all_input_ids = []
         all_input_ids_tensor = []
@@ -198,9 +199,11 @@
 
         max_tokens = 0
 
-        for i, r in enumerate(requests):
-            idx = self.requests_idx_mapping[r.id]
-            requests_idx_mapping[r.id] = i
+        for i, request_id in enumerate(request_ids):
+            idx = self.requests_idx_mapping[request_id]
+            requests_idx_mapping[request_id] = i
+
+            requests.append(self.requests[idx])
 
             # Get length
             request_input_length = self.input_lengths[idx]
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index a1a39fd4..2abb87ae 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -57,11 +57,11 @@ class Seq2SeqLMBatch(Batch):
     # Maximum number of tokens this batch will grow to
     max_tokens: int
 
-    def to_pb(self) -> generate_pb2.Batch:
-        """Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf"""
-        return generate_pb2.Batch(
+    def to_pb(self) -> generate_pb2.CachedBatch:
+        """Convert a Seq2SeqLMBatch to a text_generation_server.v1.CachedBatch protobuf"""
+        return generate_pb2.CachedBatch(
             id=self.batch_id,
-            requests=self.requests,
+            request_ids=[r.id for r in self.requests],
             size=len(self),
             max_tokens=self.max_tokens,
         )
@@ -152,18 +152,17 @@ class Seq2SeqLMBatch(Batch):
         )
 
     @tracer.start_as_current_span("filter")
-    def filter(
-        self, requests: List[generate_pb2.Request]
-    ) -> Optional["Seq2SeqLMBatch"]:
-        if len(requests) == 0:
+    def filter(self, request_ids: List[int]) -> Optional["Seq2SeqLMBatch"]:
+        if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
-        if len(requests) == len(self):
+        if len(request_ids) == len(self):
             return self
 
         keep_indices = []
 
         # New values after filtering
         requests_idx_mapping = {}
+        requests = []
         input_lengths = []
         decoder_input_lengths = []
         prefix_offsets = []
@@ -180,11 +179,12 @@ class Seq2SeqLMBatch(Batch):
 
         total_remaining_decode_tokens = 0
 
-        for i, r in enumerate(requests):
-            idx = self.requests_idx_mapping[r.id]
-            requests_idx_mapping[r.id] = i
+        for i, request_id in enumerate(request_ids):
+            idx = self.requests_idx_mapping[request_id]
+            requests_idx_mapping[request_id] = i
             keep_indices.append(idx)
 
+            requests.append(self.requests[idx])
             prefix_offsets.append(self.prefix_offsets[idx])
             read_offsets.append(self.read_offsets[idx])
 
@@ -239,7 +239,7 @@ class Seq2SeqLMBatch(Batch):
             layer[3] = layer[3][keep_indices, :, -max_input_length:]
 
         max_tokens = (
-            len(requests) * (max_input_length + max_decoder_input_length)
+            len(request_ids) * (max_input_length + max_decoder_input_length)
             + remaining_decode_tokens
         )
 
diff --git a/server/text_generation_server/models/types.py b/server/text_generation_server/models/types.py
index 8a5b82f7..66a8c212 100644
--- a/server/text_generation_server/models/types.py
+++ b/server/text_generation_server/models/types.py
@@ -12,7 +12,7 @@ from text_generation_server.pb.generate_pb2 import FinishReason
 
 class Batch(ABC):
     @abstractmethod
-    def to_pb(self) -> generate_pb2.Batch:
+    def to_pb(self) -> generate_pb2.CachedBatch:
         raise NotImplementedError
 
     @classmethod
@@ -26,7 +26,7 @@ class Batch(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def filter(self, requests: List[generate_pb2.Request]) -> "Batch":
+    def filter(self, request_ids: List[int]) -> "Batch":
         raise NotImplementedError
 
     @classmethod
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 7ca5054e..e47fd049 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -42,15 +42,13 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             self.cache.delete(request.id)
         else:
             self.cache.clear()
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
         return generate_pb2.ClearCacheResponse()
 
     async def FilterBatch(self, request, context):
         batch = self.cache.pop(request.batch_id)
         if batch is None:
             raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
-        filtered_batch = batch.filter(request.keep_requests)
+        filtered_batch = batch.filter(request.request_ids)
         self.cache.set(filtered_batch)
 
         return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())