diff --git a/backends/client/src/v3/client.rs b/backends/client/src/v3/client.rs
index 5191f8dd..8280795d 100644
--- a/backends/client/src/v3/client.rs
+++ b/backends/client/src/v3/client.rs
@@ -218,8 +218,13 @@ impl Client {
     pub async fn prefill(
         &mut self,
         batch: Batch,
+        cached_batch: Option<CachedBatch>,
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
-        let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
+        let request = tonic::Request::new(PrefillRequest {
+            batch: Some(batch),
+            cached_batch,
+        })
+        .inject_context();
         let response = self.stub.prefill(request).await?.into_inner();
         Ok((
             response.generations,
@@ -237,11 +242,7 @@ impl Client {
         &mut self,
         batches: Vec<CachedBatch>,
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
-        let request = tonic::Request::new(DecodeRequest {
-            batch: None,
-            batches,
-        })
-        .inject_context();
+        let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
         let response = self.stub.decode(request).await?.into_inner();
         Ok((
             response.generations,
diff --git a/backends/client/src/v3/sharded_client.rs b/backends/client/src/v3/sharded_client.rs
index 8872f8bd..39e99776 100644
--- a/backends/client/src/v3/sharded_client.rs
+++ b/backends/client/src/v3/sharded_client.rs
@@ -134,11 +134,12 @@ impl ShardedClient {
     pub async fn prefill(
         &mut self,
         batch: Batch,
+        cached_batch: Option<CachedBatch>,
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
-            .map(|client| Box::pin(client.prefill(batch.clone())))
+            .map(|client| Box::pin(client.prefill(batch.clone(), cached_batch.clone())))
             .collect();
         #[allow(clippy::type_complexity)]
         let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
@@ -256,7 +257,7 @@ impl Health for ShardedClient {
             max_tokens: 2,
             max_blocks: 1,
         };
-        self.clone().prefill(batch).await?;
+        self.clone().prefill(batch, None).await?;
         Ok(())
     }
 }
diff --git a/backends/v2/src/backend.rs b/backends/v2/src/backend.rs
index 086fc6dc..bc264138 100644
--- a/backends/v2/src/backend.rs
+++ b/backends/v2/src/backend.rs
@@ -6,7 +6,7 @@ use nohash_hasher::IntMap;
 use std::sync::Arc;
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
 use text_generation_router::validation::ValidGenerateRequest;
-use text_generation_router::{Attention, FinishReason, PrefillToken, Token};
+use text_generation_router::{FinishReason, PrefillToken, Token};
 use tokio::sync::mpsc::error::SendError;
 use tokio::sync::{mpsc, Notify};
 use tokio::time::Instant;
@@ -36,18 +36,14 @@ impl BackendV2 {
         speculate: u32,
     ) -> Self {
         // Infer shared state
-        let attention = if let Ok(attention) = std::env::var("ATTENTION") {
-            attention
-                .parse()
-                .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"))
-        } else {
-            Attention::Paged
-        };
-        let block_size = if attention == Attention::FlashDecoding {
-            256
-        } else {
-            16
+        let attention = std::env::var("ATTENTION").unwrap_or("paged".to_string());
+        let block_size = match attention.as_str() {
+            "flashinfer" => 1,
+            "flashdecoding" => 256,
+            "paged" => 16,
+            _ => unreachable!(),
         };
+
         let queue = Queue::new(requires_padding, block_size, window_size, speculate);
         let batching_task_notifier = Arc::new(Notify::new());
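Since BackendV2 now parses `ATTENTION` directly instead of going through the router's `Attention` enum, the mapping is easy to sanity-check in isolation. A minimal, std-only sketch, assuming the same strings and block sizes as the hunk above; the fallback to `"paged"` also mirrors it, and the `println!` is purely illustrative:

```rust
use std::env;

fn main() {
    // Same mapping as the new BackendV2 code: the attention implementation selected
    // via the ATTENTION environment variable determines the KV-cache block size.
    let attention = env::var("ATTENTION").unwrap_or_else(|_| "paged".to_string());
    let block_size: u32 = match attention.as_str() {
        "flashinfer" => 1,
        "flashdecoding" => 256,
        "paged" => 16,
        other => panic!("Invalid attention variant was specified: `{other}`"),
    };
    println!("attention={attention} block_size={block_size}");
}
```

For the v3 backend the lookup disappears entirely: as the later hunks show, the Python shard now reports `block_size`, `attention_impl`, and `use_prefix_caching` through `InfoResponse`, so the router and the shard cannot disagree about these values.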
diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs
index 84152ff8..a5c0f512 100644
--- a/backends/v3/src/backend.rs
+++ b/backends/v3/src/backend.rs
@@ -1,12 +1,14 @@
-use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient};
 /// Batching and inference logic
+use crate::client::{
+    Batch, CachedBatch, ClientError, Generation, Health, InfoResponse, ShardedClient,
+};
 use crate::queue::{Entry, Queue};
 use async_trait::async_trait;
 use nohash_hasher::IntMap;
 use std::sync::Arc;
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
 use text_generation_router::validation::ValidGenerateRequest;
-use text_generation_router::{Attention, FinishReason, PrefillToken, Token};
+use text_generation_router::{FinishReason, PrefillToken, Token};
 use tokio::sync::mpsc::error::SendError;
 use tokio::sync::{mpsc, Notify};
 use tokio::time::Instant;
@@ -31,32 +33,22 @@ impl BackendV3 {
         max_batch_total_tokens: u32,
         max_waiting_tokens: usize,
         max_batch_size: Option<usize>,
-        requires_padding: bool,
-        window_size: Option<u32>,
-        speculate: u32,
-        support_chunking: bool,
+        shard_info: InfoResponse,
     ) -> Self {
-        if support_chunking {
+        if shard_info.support_chunking {
             tracing::warn!("Model supports prefill chunking. `waiting_served_ratio` and `max_waiting_tokens` will be ignored.");
         }
 
-        let prefix_caching = std::env::var("USE_PREFIX_CACHING").unwrap_or("1".to_string());
-        let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1");
-        let attention: String = std::env::var("ATTENTION").unwrap_or("flashinfer".to_string());
-
-        let attention: Attention = attention
-            .parse()
-            .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"));
-        let block_size = attention.block_size();
+        let block_size = shard_info.block_size;
 
         let queue = Queue::new(
-            requires_padding,
+            shard_info.requires_padding,
             block_size,
-            prefix_caching,
-            window_size,
-            speculate,
+            shard_info.use_prefix_caching,
+            shard_info.window_size,
+            shard_info.speculate,
             max_batch_total_tokens,
-            support_chunking,
+            shard_info.support_chunking,
         );
         let batching_task_notifier = Arc::new(Notify::new());
@@ -68,7 +60,7 @@ impl BackendV3 {
             max_batch_total_tokens,
             max_waiting_tokens,
             max_batch_size,
-            support_chunking,
+            shard_info.support_chunking,
             queue.clone(),
             batching_task_notifier.clone(),
         ));
@@ -154,7 +146,7 @@ pub(crate) async fn batching_task(
         )
         .await
         {
-            let mut cached_batch = prefill(&mut client, batch, &mut entries)
+            let mut cached_batch = prefill(&mut client, batch, None, &mut entries)
                 .instrument(span)
                 .await;
             let mut waiting_tokens = 1;
@@ -175,7 +167,8 @@ pub(crate) async fn batching_task(
                 let (min_size, max_size, prefill_token_budget) = if support_chunking {
                     // Since the next batch will be concatenated with the current batch,
                     // the current batch tokens must be subtracted to the prefill budget
-                    let prefill_token_budget = max_batch_prefill_tokens - current_tokens;
+                    let prefill_token_budget =
+                        max_batch_prefill_tokens.saturating_sub(current_tokens);
                     // We can ignore min_size and max_size
                     // Models than rely on max_size cannot support chunking
                     // Regarding min_size, chunking allow us to consistently run at the compute
@@ -199,10 +192,8 @@ pub(crate) async fn batching_task(
                     (min_size, max_size, max_batch_prefill_tokens)
                 };
 
-                let mut additional_batch = None;
-
                 // Try to get a new batch
-                if let Some((mut new_entries, new_batch, span)) = queue
+                if let Some((new_entries, new_batch, span)) = queue
                     .next_batch(min_size, max_size, prefill_token_budget, token_budget)
                     .await
                 {
@@ -218,11 +209,11 @@ pub(crate) async fn batching_task(
                         };
                         counter.increment(1);
                     }
-
-                    if support_chunking {
-                        entries.extend(new_entries);
-                        additional_batch = Some(new_batch);
+                    let cached_batch = if support_chunking {
+                        // Concat current batch to the new one
+                        batches.pop()
                     } else {
+                        // Requests are waiting only if we don't support chunking
                        entries.iter_mut().for_each(|(_, entry)| {
                            // Create a new span to add the info that this entry is waiting
                            // because a new batch is being computed
@@ -233,18 +224,23 @@ pub(crate) async fn batching_task(
                            // Update entry
                            entry.temp_span = Some(entry_waiting_span);
                        });
+                        None
+                    };
+                    entries.extend(new_entries);
 
-                    // Generate one token for this new batch to have the attention past in cache
-                    let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
+                    // Generate one token for this new batch to have the attention past in cache
+                    let new_cached_batch =
+                        prefill(&mut client, new_batch, cached_batch, &mut entries)
                        .instrument(span)
                        .await;
-                    // Reset waiting counter
-                    waiting_tokens = 1;
-                    // Extend current batch with the new batch
-                    if let Some(new_cached_batch) = new_cached_batch {
-                        entries.extend(new_entries);
-                        batches.push(new_cached_batch);
-                    }
+                    // Reset waiting counter
+                    waiting_tokens = 1;
+                    // Extend current batch with the new batch
+                    if let Some(new_cached_batch) = new_cached_batch {
+                        batches.push(new_cached_batch);
+                    } else if support_chunking {
+                        // New cached batch is empty, no work left
+                        break;
                    }
                }
@@ -262,7 +258,7 @@ pub(crate) async fn batching_task(
                    entry.temp_span = Some(entry_batch_span);
                });
 
-                cached_batch = decode(&mut client, additional_batch, batches, &mut entries)
+                cached_batch = decode(&mut client, batches, &mut entries)
                    .instrument(next_batch_span)
                    .await;
                waiting_tokens += 1;
@@ -277,13 +273,14 @@ pub(crate) async fn batching_task(
 async fn prefill(
     client: &mut ShardedClient,
     batch: Batch,
+    cached_batch: Option<CachedBatch>,
     entries: &mut IntMap<u64, Entry>,
 ) -> Option<CachedBatch> {
     let start_time = Instant::now();
     let batch_id = batch.id;
     metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
 
-    match client.prefill(batch).await {
+    match client.prefill(batch, cached_batch).await {
         Ok((generations, next_batch, timings)) => {
             let start_filtering_time = Instant::now();
             // Send generated tokens and filter stopped entries
@@ -292,6 +289,10 @@ async fn prefill(
             // Filter next batch and remove requests that were stopped
             let next_batch = filter_batch(client, next_batch, entries).await;
 
+            if let Some(concat_duration) = timings.concat {
+                metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
+                    .record(concat_duration.as_secs_f64());
+            }
             metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill")
                 .record(timings.forward.as_secs_f64());
             metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
@@ -316,7 +317,6 @@ async fn prefill(
 #[instrument(skip_all)]
 async fn decode(
     client: &mut ShardedClient,
-    batch: Option<Batch>,
     batches: Vec<CachedBatch>,
     entries: &mut IntMap<u64, Entry>,
 ) -> Option<CachedBatch> {
@@ -324,7 +324,7 @@ async fn decode(
     let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
     metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);
 
-    match client.decode(batch, batches).await {
+    match client.decode(batches).await {
         Ok((generations, next_batch, timings)) => {
             let start_filtering_time = Instant::now();
             // Send generated tokens and filter stopped entries
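The switch to `saturating_sub` in the `@@ -175` hunk above guards the chunking budget against underflow. A small, runnable illustration with made-up numbers (only the names `max_batch_prefill_tokens` and `current_tokens` come from the code above; the values and the loop are hypothetical):

```rust
fn main() {
    let max_batch_prefill_tokens: u32 = 4096; // hypothetical configuration value
    for current_tokens in [1024u32, 4096, 6000] {
        // Plain `-` would panic with "attempt to subtract with overflow" in debug builds
        // once current_tokens exceeds the configured budget; saturating_sub clamps to 0,
        // and a zero budget now makes Queue::next_batch return None immediately
        // (see the queue.rs hunk later in this patch).
        let prefill_token_budget = max_batch_prefill_tokens.saturating_sub(current_tokens);
        println!("current_tokens={current_tokens} -> prefill_token_budget={prefill_token_budget}");
    }
}
```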
diff --git a/backends/v3/src/client/grpc_client.rs b/backends/v3/src/client/grpc_client.rs
index ab93db9b..804c77d4 100644
--- a/backends/v3/src/client/grpc_client.rs
+++ b/backends/v3/src/client/grpc_client.rs
@@ -218,13 +218,23 @@ impl Client {
     pub async fn prefill(
         &mut self,
         batch: Batch,
+        cached_batch: Option<CachedBatch>,
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
-        let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
+        let request = tonic::Request::new(PrefillRequest {
+            batch: Some(batch),
+            cached_batch,
+        })
+        .inject_context();
         let response = self.stub.prefill(request).await?.into_inner();
         Ok((
             response.generations,
             response.batch,
-            PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
+            PrefillTimings::new(
+                response.concat_ns,
+                response.forward_ns,
+                response.decode_ns,
+                response.total_ns,
+            ),
         ))
     }
 
@@ -235,10 +245,9 @@ impl Client {
     #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
     pub async fn decode(
         &mut self,
-        batch: Option<Batch>,
         batches: Vec<CachedBatch>,
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
-        let request = tonic::Request::new(DecodeRequest { batches, batch }).inject_context();
+        let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
         let response = self.stub.decode(request).await?.into_inner();
         Ok((
             response.generations,
@@ -254,14 +263,16 @@ impl Client {
 }
 
 pub struct PrefillTimings {
+    pub concat: Option<Duration>,
     pub forward: Duration,
     pub decode: Duration,
     pub total: Duration,
 }
 
 impl PrefillTimings {
-    fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
+    fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
         Self {
+            concat: concat_ns.map(Duration::from_nanos),
             forward: Duration::from_nanos(forward_ns),
             decode: Duration::from_nanos(decode_ns),
             total: Duration::from_nanos(total_ns),
diff --git a/backends/v3/src/client/sharded_client.rs b/backends/v3/src/client/sharded_client.rs
index 8af4b26f..e25bf71e 100644
--- a/backends/v3/src/client/sharded_client.rs
+++ b/backends/v3/src/client/sharded_client.rs
@@ -135,11 +135,12 @@ impl ShardedClient {
     pub async fn prefill(
         &mut self,
         batch: Batch,
+        cached_batch: Option<CachedBatch>,
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
-            .map(|client| Box::pin(client.prefill(batch.clone())))
+            .map(|client| Box::pin(client.prefill(batch.clone(), cached_batch.clone())))
             .collect();
         #[allow(clippy::type_complexity)]
         let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
@@ -167,13 +168,12 @@ impl ShardedClient {
     #[instrument(skip_all, fields(size = batches.iter().map(|batch| {batch.size}).sum::<u32>()))]
     pub async fn decode(
         &mut self,
-        batch: Option<Batch>,
         batches: Vec<CachedBatch>,
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
-            .map(|client| Box::pin(client.decode(batch.clone(), batches.clone())))
+            .map(|client| Box::pin(client.decode(batches.clone())))
             .collect();
         #[allow(clippy::type_complexity)]
         let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
@@ -246,7 +246,7 @@ impl Health for ShardedClient {
             max_tokens: 2,
             max_blocks: 1,
         };
-        self.clone().prefill(batch).await?;
+        self.clone().prefill(batch, None).await?;
         Ok(())
     }
 }
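To make the new client contract concrete: `prefill` now takes the cached batch to concatenate with, while `decode` loses its separate optional `Batch` argument. A hedged, std-only mock of that call shape (the real `Batch` and `CachedBatch` are tonic-generated types and the real methods are async gRPC calls; everything below is a stand-in for illustration only):

```rust
#[derive(Clone, Debug)]
struct Batch {
    id: u64,
}

#[derive(Clone, Debug)]
struct CachedBatch {
    id: u64,
}

// Mirrors the updated signature: prefill optionally receives the cached batch that
// the shard should concatenate with `batch` before running the forward pass.
fn prefill(batch: Batch, cached_batch: Option<CachedBatch>) -> CachedBatch {
    match &cached_batch {
        Some(cached) => println!("prefill chunk {} on top of cached batch {}", batch.id, cached.id),
        None => println!("prefill fresh batch {}", batch.id),
    }
    CachedBatch { id: batch.id }
}

fn main() {
    // Health checks and the benchmark keep the old behaviour by passing None.
    let first = prefill(Batch { id: 0 }, None);
    // With prefill chunking enabled, the next chunk is prefilled on top of the cached
    // state, which is why decode() no longer needs its own optional Batch argument.
    let _second = prefill(Batch { id: 1 }, Some(first));
}
```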
diff --git a/backends/v3/src/lib.rs b/backends/v3/src/lib.rs
index 0a7ef223..7daf9eae 100644
--- a/backends/v3/src/lib.rs
+++ b/backends/v3/src/lib.rs
@@ -31,6 +31,12 @@ pub struct BackendInfo {
     pub max_batch_size: Option<usize>,
     #[schema(example = "false")]
     pub support_chunking: bool,
+    #[schema(example = "false")]
+    pub prefix_caching: bool,
+    #[schema(example = "flashinfer")]
+    pub attention_impl: String,
+    #[schema(example = "1")]
+    pub block_size: u32,
 }
 
 #[allow(clippy::too_many_arguments)]
@@ -113,6 +119,9 @@ pub async fn connect_backend(
         model_dtype: shard_info.dtype.clone(),
         speculate: shard_info.speculate as usize,
         support_chunking: shard_info.support_chunking,
+        prefix_caching: shard_info.use_prefix_caching,
+        attention_impl: shard_info.attention_impl.clone(),
+        block_size: shard_info.block_size,
     };
 
     let backend = BackendV3::new(
@@ -122,10 +131,7 @@ pub async fn connect_backend(
         max_batch_total_tokens,
         max_waiting_tokens,
         max_batch_size,
-        shard_info.requires_padding,
-        shard_info.window_size,
-        shard_info.speculate,
-        shard_info.support_chunking,
+        shard_info,
     );
 
     tracing::info!("Using backend V3");
diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs
index 7db0aba3..a07c725c 100644
--- a/backends/v3/src/queue.rs
+++ b/backends/v3/src/queue.rs
@@ -89,6 +89,10 @@ impl Queue {
         prefill_token_budget: u32,
         token_budget: u32,
     ) -> Option<NextBatch> {
+        if prefill_token_budget == 0 || token_budget == 0 {
+            return None;
+        };
+
         // Create response channel
         let (response_sender, response_receiver) = oneshot::channel();
         // Send next batch command to the background task managing the state
diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs
index fff221ef..43a84e70 100644
--- a/benchmark/src/generation.rs
+++ b/benchmark/src/generation.rs
@@ -174,7 +174,7 @@ async fn prefill(
 
     // Run prefill
     let start_time = Instant::now();
-    let (_, decode_batch, _) = client.prefill(batch.clone()).await?;
+    let (_, decode_batch, _) = client.prefill(batch.clone(), None).await?;
 
     // Get latency
     let latency = start_time.elapsed();
diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto
index 15a93ac9..e4dfefef 100644
--- a/proto/v3/generate.proto
+++ b/proto/v3/generate.proto
@@ -35,6 +35,9 @@ message InfoResponse {
     optional uint32 window_size = 4;
     uint32 speculate = 5;
     bool support_chunking = 6;
+    bool use_prefix_caching = 7;
+    string attention_impl = 8;
+    uint32 block_size = 9;
 }
 
 /// Empty request
@@ -225,6 +228,8 @@ message FilterBatchResponse {
 message PrefillRequest {
     /// Batch
     Batch batch = 1;
+    /// Optional cached batch
+    CachedBatch cached_batch = 2;
 }
 
 message PrefillResponse {
@@ -238,13 +243,13 @@ message PrefillResponse {
     uint64 decode_ns = 4;
     /// Total elapsed time in nanoseconds
     uint64 total_ns = 5;
+    /// Concatenate elapsed time in nanoseconds
+    optional uint64 concat_ns = 6;
 }
 
 message DecodeRequest {
     /// Cached batches
     repeated CachedBatch batches = 1;
-    /// Optional Batch
-    optional Batch batch = 2;
 }
 
 message DecodeResponse {
diff --git a/router/src/lib.rs b/router/src/lib.rs
index b29c9395..fdbd931e 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -18,45 +18,6 @@ use tracing::warn;
 use utoipa::ToSchema;
 use validation::Validation;
 
-#[derive(PartialEq)]
-pub enum Attention {
-    Paged,
-    FlashDecoding,
-    FlashInfer,
-}
-
-impl Attention {
-    pub fn block_size(&self) -> u32 {
-        match self {
-            Attention::FlashDecoding => 256,
-            Attention::FlashInfer => 1,
-            Attention::Paged => 16,
-        }
-    }
-}
-
-#[derive(Debug)]
-pub struct ParseError;
-
-impl std::fmt::Display for ParseError {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "Cannot parse attention value")
-    }
-}
-impl std::error::Error for ParseError {}
-
-impl std::str::FromStr for Attention {
-    type Err = ParseError;
-    fn from_str(s: &str) -> Result<Self, Self::Err> {
-        match s {
-            "paged" => Ok(Attention::Paged),
-            "flashdecoding" => Ok(Attention::FlashDecoding),
-            "flashinfer" => Ok(Attention::FlashInfer),
-            _ => Err(ParseError),
-        }
-    }
-}
-
 /// Hub type
 #[derive(Clone, Debug, Deserialize)]
 pub struct HubModelInfo {
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 1378f590..0e5c0163 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -76,7 +76,7 @@ class CausalLMBatch(Batch):
             request_ids=[r.id for r in self.requests],
             size=len(self),
             max_tokens=self.max_tokens,
-            current_tokens=len(self),
+            current_tokens=len(self.input_ids),
         )
 
     @classmethod
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index b283a5fb..65552ff7 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -171,7 +171,7 @@ class FlashCausalLMBatch(Batch):
     # Will be set by `generate_token` and reset after each prefill forward
     prefill_cu_outlens: Optional[List[int]]
     # Will be set by `generate_token` and reset after each prefill forward
-    prefill_tokens: List[Optional[Tokens]]
+    prefill_logprob_tokens: List[Optional[Tokens]]
 
     # Prefixes
     prefix_ids: List[List[int]]
@@ -290,8 +290,7 @@ class FlashCausalLMBatch(Batch):
                 prefix_length <= prompt_length
             ), f"Prefix {prefix_length} vs input {prompt_length}"
             if prefix_length == prompt_length:
-                assert prefix_length > 0
-                prefix_length -= 1
+                assert False, "unreachable"
             if prefix_length + postfix_length < prompt_length:
                 # FIXME: speculate is not supported for context chunking at the moment
                 assert speculate == 0
@@ -303,7 +302,9 @@ class FlashCausalLMBatch(Batch):
                 prefix_length : prefix_length + postfix_length
             ]
 
-            postfix_length = len(postfix_ids)
+            assert (
+                len(postfix_ids) == postfix_length
+            ), "Rust and Python tokenizers are not aligned"
 
             postfix_lengths.append(postfix_length)
             prefix_offsets.append(prompt_length - 5)
@@ -394,7 +395,7 @@ class FlashCausalLMBatch(Batch):
            max_current_length=max_current_length,
            prefilling=True,
            prefilling_mask=[True] * len(pb.requests),
-            prefill_tokens=[None] * len(pb.requests),
+            prefill_logprob_tokens=[None] * len(pb.requests),
            postfix_lengths=postfix_lengths,
            prompt_lengths=prompt_lengths,
            prefix_offsets=prefix_offsets,
@@ -475,7 +476,7 @@ class FlashCausalLMBatch(Batch):
         read_offsets = []
 
         prefilling_mask = []
-        prefill_tokens = []
+        prefill_logprob_tokens = []
 
         stopping_criterias = []
         top_n_tokens = []
@@ -518,7 +519,7 @@ class FlashCausalLMBatch(Batch):
             stopping_criterias.append(stopping_criteria)
 
             top_n_tokens.append(self.top_n_tokens[idx])
-            prefill_tokens.append(self.prefill_tokens[idx])
+            prefill_logprob_tokens.append(self.prefill_logprob_tokens[idx])
 
             ADAPTER_TO_INDEX = get_adapter_to_index()
             adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0)
@@ -611,7 +612,7 @@ class FlashCausalLMBatch(Batch):
             prefill_head_indices=None,
             prefill_next_token_indices=None,
             prefill_cu_outlens=None,
-            prefill_tokens=prefill_tokens,
+            prefill_logprob_tokens=prefill_logprob_tokens,
             prompt_lengths=prompt_lengths,
             prompt_lengths_tensor=prompt_lengths_tensor,
             postfix_lengths=postfix_lengths,
@@ -726,7 +727,7 @@ class FlashCausalLMBatch(Batch):
         prefix_offsets = []
         read_offsets = []
 
-        prefill_tokens = []
+        prefill_logprob_tokens = []
 
         next_token_chooser_parameters = []
         fsm_grammar_states = []
@@ -814,7 +815,7 @@ class FlashCausalLMBatch(Batch):
             prefix_offsets.extend(batch.prefix_offsets)
             read_offsets.extend(batch.read_offsets)
 
-            prefill_tokens.extend(batch.prefill_tokens)
+            prefill_logprob_tokens.extend(batch.prefill_logprob_tokens)
 
             next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
             fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)
@@ -869,7 +870,7 @@ class FlashCausalLMBatch(Batch):
             prefill_head_indices=None,
             prefill_next_token_indices=None,
             prefill_cu_outlens=None,
-            prefill_tokens=prefill_tokens,
+            prefill_logprob_tokens=prefill_logprob_tokens,
             prompt_lengths=prompt_lengths,
             prompt_lengths_tensor=prompt_lengths_tensor,
             postfix_lengths=postfix_lengths,
@@ -1769,9 +1770,10 @@ class FlashCausalLM(Model):
         if get_support_chunking():
             next_prefilling_mask = []
             # Budget in tokens for the next batch
-            # We remove len(batch) to always have enough space for at least a single decode
-            # for the remaining requests
-            batch_budget = get_max_prefill_tokens() - len(batch)
+            # We remove (len(batch) - 1) to always have enough space for at least a single decode
+            # for the remaining requests; -1 because the first request does not need to be removed from the budget
+            # (e.g. with a single request in the batch, it should get the full budget, not budget - 1)
+            batch_budget = get_max_prefill_tokens() - (len(batch) - 1)
             # We reverse to prioritize older requests
             # zip() is not reversible so reverse the underlying lists instead
             for prefix_length, postfix_length, prompt_length in zip(
@@ -1790,6 +1792,7 @@ class FlashCausalLM(Model):
                     finished_prefilling = False
                     next_prefilling_mask.append(True)
                 else:
+                    # FIXME: use true number of accepted tokens instead of 1
                    # Since speculation will be turned off, this is always true
                    next_chunk_length = 1
                    next_prefilling_mask.append(False)
@@ -1807,14 +1810,7 @@ class FlashCausalLM(Model):
         batch.prefilling = not finished_prefilling
         batch.prefilling_mask = next_prefilling_mask
 
-        # Turn off speculative if some requests are still prefilling
-        # It makes the logic easier to follow
-        if prefill and not finished_prefilling:
-            speculate = 0
-            speculative_logits = None
-        else:
-            speculate = get_speculate()
-
+        speculate = get_speculate()
         (
             next_input_ids,
             next_token_logprobs,
@@ -1914,7 +1910,7 @@ class FlashCausalLM(Model):
                    ] = next_input_ids[index]
                    index += 1
 
-                cumulative_length += postfix_length
+            cumulative_length += postfix_length
 
         # Update values
         # These values can be updated without a GPU -> CPU sync
@@ -2045,18 +2041,18 @@ class FlashCausalLM(Model):
             # this state to be stable
             if request.id % self.world_size == self.rank:
                 # Prefill
-                if prefill and request.prefill_logprobs:
+                if request_prefilling and request.prefill_logprobs:
                     out_start_index = batch.prefill_cu_outlens[i]
                     out_end_index = batch.prefill_cu_outlens[i + 1]
 
-                    request_prefill_tokens = batch.prefill_tokens[i]
-
                     request_prefill_logprobs = prefill_logprobs[
                         out_start_index : out_end_index - 1
                     ]
                     prefill_token_ids = all_input_ids[:-1]
 
-                    if request_prefill_tokens is None:
+                    past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i]
+
+                    if past_prefill_logprob_tokens is None:
                         # Remove generated token to only have prefill and add nan for first prompt token
                         request_prefill_logprobs = [float("nan")] * (
                             len(prefix_ids) + 1
@@ -2069,18 +2065,20 @@ class FlashCausalLM(Model):
                         skip_special_tokens=False,
                    )
 
-                    prefill_tokens = Tokens(
+                    prefill_logprob_tokens = Tokens(
                        prefill_token_ids,
                        request_prefill_logprobs,
                        prefill_texts,
                        is_special=[],
                    )
-                    if request_prefill_tokens is not None:
-                        prefill_tokens = request_prefill_tokens + prefill_tokens
+                    if past_prefill_logprob_tokens is not None:
+                        prefill_logprob_tokens = (
+                            past_prefill_logprob_tokens + prefill_logprob_tokens
+                        )
 
-                    batch.prefill_tokens[i] = prefill_tokens
+                    batch.prefill_logprob_tokens[i] = prefill_logprob_tokens
                 else:
-                    batch.prefill_tokens[i] = None
+                    batch.prefill_logprob_tokens[i] = None
 
                 # If it is, the tokens we decoded should be ignored
                 if request_prefilling:
@@ -2178,7 +2176,7 @@ class FlashCausalLM(Model):
 
                generation = Generation(
                    request.id,
-                    batch.prefill_tokens[i],
+                    batch.prefill_logprob_tokens[i],
                    Tokens(
                        _next_token_ids,
                        _next_token_logprobs,
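The budget change in the `@@ -1769` hunk above is easiest to see with numbers. A small sketch of the same arithmetic, written in Rust purely for illustration (`max_prefill_tokens` is a made-up stand-in for `get_max_prefill_tokens()`):

```rust
fn main() {
    let max_prefill_tokens: usize = 4096; // stand-in for get_max_prefill_tokens()
    for batch_len in [1usize, 2, 8] {
        // Old rule: reserve one decode slot per request, including the one being chunked.
        let old_budget = max_prefill_tokens - batch_len;
        // New rule: the first request needs no reservation, so a lone request
        // gets the full budget instead of budget - 1.
        let new_budget = max_prefill_tokens - (batch_len - 1);
        println!("batch_len={batch_len}: old={old_budget} new={new_budget}");
    }
}
```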
diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py
index 02f3dbf9..05d36ba3 100644
--- a/server/text_generation_server/models/model.py
+++ b/server/text_generation_server/models/model.py
@@ -7,6 +7,7 @@ from collections import defaultdict
 from transformers import PreTrainedTokenizerBase
 from loguru import logger
 
+from text_generation_server.models.globals import ATTENTION, PREFIX_CACHING, BLOCK_SIZE
 from text_generation_server.models.types import Batch, Generation
 from text_generation_server.utils.log import log_master
 from text_generation_server.utils.prefill_chunking import set_support_chunking
@@ -94,6 +95,9 @@ class Model(ABC):
             window_size=self.sliding_window,
             speculate=self.speculate,
             support_chunking=self.support_chunking,
+            use_prefix_caching=PREFIX_CACHING,
+            attention_impl=ATTENTION,
+            block_size=BLOCK_SIZE,
         )
 
     @property
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index e2d7aa4d..0a1d0824 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -80,7 +80,7 @@ class Seq2SeqLMBatch(Batch):
             request_ids=[r.id for r in self.requests],
             size=len(self),
             max_tokens=self.max_tokens,
-            current_tokens=len(self),
+            current_tokens=len(self.input_ids),
         )
 
     @classmethod
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index d89df966..cc7979d4 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -153,6 +153,18 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             request.batch, self.model.tokenizer, self.model.dtype, self.model.device
         )
 
+        concat_ns = None
+        if self.model.support_chunking:
+            if request.HasField("cached_batch"):
+                cached_batch = self.cache.pop(request.cached_batch.id)
+                if cached_batch is None:
+                    raise ValueError(
+                        f"Batch ID {request.cached_batch.id} not found in cache."
+                    )
+                start_concat = time.time_ns()
+                batch = self.model.batch_type.concatenate([batch, cached_batch])
+                concat_ns = time.time_ns() - start_concat
+
         generations, next_batch, timings = self.model.generate_token(batch)
         self.cache.set(next_batch)
 
@@ -162,6 +174,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             forward_ns=timings[0],
             decode_ns=timings[1],
             total_ns=time.time_ns() - start,
+            concat_ns=concat_ns,
         )
 
     async def Decode(self, request, context):
@@ -179,16 +192,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         if len(batches) == 0:
             raise ValueError("All batches are empty")
 
-        if self.model.support_chunking:
-            if request.HasField("batch"):
-                batch = self.model.batch_type.from_pb(
-                    request.batch,
-                    self.model.tokenizer,
-                    self.model.dtype,
-                    self.model.device,
-                )
-                batches.append(batch)
-
         if len(batches) > 1:
             start_concat = time.time_ns()
             batch = self.model.batch_type.concatenate(batches)
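To summarize the new `Prefill` path: when chunking is enabled, the handler pops the previously cached batch, concatenates it with the incoming chunk, and reports the extra time as `concat_ns`. A hedged Rust sketch of that control flow (the real implementation is the Python code above; `Batch`, `concatenate`, and the in-memory cache here are simplified stand-ins):

```rust
use std::collections::HashMap;
use std::time::Instant;

#[derive(Debug)]
struct Batch {
    id: u64,
    tokens: Vec<u32>,
}

// Stand-in for Batch.concatenate: merge the cached batch back into the incoming chunk.
fn concatenate(batches: Vec<Batch>) -> Batch {
    let id = batches.last().map(|b| b.id).unwrap_or(0);
    let tokens = batches.into_iter().flat_map(|b| b.tokens).collect();
    Batch { id, tokens }
}

fn main() {
    // The shard keeps batches whose KV cache is already resident, keyed by batch id.
    let mut cache: HashMap<u64, Batch> = HashMap::new();
    cache.insert(0, Batch { id: 0, tokens: vec![1, 2, 3] });

    let incoming = Batch { id: 1, tokens: vec![4, 5] };
    let cached_batch_id: Option<u64> = Some(0); // PrefillRequest.cached_batch, when chunking

    let mut concat_ns: Option<u128> = None;
    let batch = if let Some(id) = cached_batch_id {
        // Mirrors `self.cache.pop(...)` plus the ValueError when the id is unknown.
        let cached = cache.remove(&id).expect("Batch ID not found in cache");
        let start = Instant::now();
        let merged = concatenate(vec![incoming, cached]);
        concat_ns = Some(start.elapsed().as_nanos());
        merged
    } else {
        incoming
    };
    // concat_ns is only set when a concatenation actually happened, which is why
    // PrefillResponse.concat_ns is an optional field.
    println!("prefilling batch {} ({} tokens), concat_ns={:?}", batch.id, batch.tokens.len(), concat_ns);
}
```

Because `concat_ns` stays unset when no cached batch was supplied, the router only records `tgi_batch_concat_duration` for prefills that actually performed a concatenation, matching the optional `concat_ns` field added to `PrefillResponse`.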