diff --git a/proto/generate.proto b/proto/generate.proto
index cc14cbf8..98b3d026 100644
--- a/proto/generate.proto
+++ b/proto/generate.proto
@@ -132,7 +132,7 @@ message PrefillResponse {
     /// Generation
     repeated Generation generations = 1;
     /// Next batch (cached)
-    optional Batch batch = 2;
+    Batch batch = 2;
 }
 
 message DecodeRequest {
@@ -144,5 +144,5 @@ message DecodeResponse {
     /// Decodes
     repeated Generation generations = 1;
     /// Next batch (cached)
-    optional Batch batch = 2;
+    Batch batch = 2;
 }
\ No newline at end of file
diff --git a/router/src/infer.rs b/router/src/infer.rs
index 5a4375ae..9618264d 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -7,9 +7,8 @@ use futures::future::try_join_all;
 use futures::stream::StreamExt;
 use nohash_hasher::IntMap;
 use std::sync::Arc;
-use text_generation_client::{
-    Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
-};
+use flume::SendError;
+use text_generation_client::{Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient};
 use thiserror::Error;
 use tokio::sync::{Notify, Semaphore, TryAcquireError};
 use tokio::time::Instant;
@@ -339,7 +338,21 @@ async fn prefill(
 
     match client.prefill(batch).await {
         Ok((generations, next_batch)) => {
-            send_generations(generations, entries);
+            filter_send_generations(generations, entries);
+
+            let next_batch = {
+                let mut batch = next_batch.expect("next_batch is None. This is a bug.");
+
+                batch.requests = batch.requests.into_iter().filter(|r| { entries.contains_key(&r.id) }).collect();
+                let size = batch.requests.len();
+                if size == 0 {
+                    let _ = client.clear_cache(Some(batch.id)).await;
+                    return None;
+                }
+                batch.size = size as u32;
+                Some(batch)
+            };
+
             metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
             metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
             next_batch
@@ -361,17 +374,35 @@ async fn decode(
     entries: &mut IntMap<u64, Entry>,
 ) -> Option<Batch> {
     let start_time = Instant::now();
+    let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
     metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
 
     match client.decode(batches).await {
         Ok((generations, next_batch)) => {
-            send_generations(generations, entries);
+            filter_send_generations(generations, entries);
+
+            let next_batch = {
+                let mut batch = next_batch.expect("next_batch is None. This is a bug.");
+
+                batch.requests = batch.requests.into_iter().filter(|r| { entries.contains_key(&r.id) }).collect();
+                let size = batch.requests.len();
+                if size == 0 {
+                    let _ = client.clear_cache(Some(batch.id)).await;
+                    return None;
+                }
+                batch.size = size as u32;
+                Some(batch)
+            };
+
             metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
             metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
             next_batch
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
+            for id in batch_ids {
+                let _ = client.clear_cache(Some(id)).await;
+            }
             send_errors(err, entries);
             metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode");
             None
@@ -398,64 +429,66 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
 }
 
 /// Send one or multiple `InferStreamResponse` to Infer for all `entries`
+/// and filter entries
 #[instrument(skip_all)]
-fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
+fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
     generations.into_iter().for_each(|generation| {
+        let id = generation.request_id;
         // Get entry
         // We can `expect` here as the request id should always be in the entries
         let entry = entries
-            .get(&generation.request_id)
+            .get(&id)
             .expect("ID not found in entries. This is a bug.");
 
         // Create and enter a span to link this function back to the entry
-        let _generation_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
-
-        if let Some(prefill_tokens) = generation.prefill_tokens {
-            // Send message
-            // unwrap_or is valid here as we don't care if the receiver is gone.
-            entry
-                .response_tx
-                .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))
-                .unwrap_or(());
-        }
-
-        // Create last Token
-        let token = Token {
-            id: generation.token_id,
-            text: generation.token_text,
-            logprob: generation.token_logprob,
-            special: generation.token_is_special,
-        };
-
-        if let Some(generated_text) = generation.generated_text {
-            // Remove entry as this is the last message
-            // We can `expect` here as the request id should always be in the entries
-            let entry = entries
-                .remove(&generation.request_id)
-                .expect("ID not found in entries. This is a bug.");
-
-            // Send message
-            // unwrap_or is valid here as we don't care if the receiver is gone.
-            entry
-                .response_tx
-                .send(Ok(InferStreamResponse::End {
-                    token,
-                    generated_text,
-                    queued: entry.queue_time,
-                    start: entry.batch_time.unwrap(),
-                }))
-                .unwrap_or(());
-        } else {
-            // Send message
-            // unwrap_or is valid here as we don't care if the receiver is gone.
-            entry
-                .response_tx
-                .send(Ok(InferStreamResponse::Token(token)))
-                .unwrap_or(());
+        let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
+        // Send generation back to infer task
+        // If we receive an error from the Flume channel, we need to stop generating for this
+        // request, hence the unwrap_or(true)
+        let stopped = send_generation(generation, entry).unwrap_or(true);
+        if stopped {
+            entries.remove(&id).expect("ID not found in entries. This is a bug.");
         }
     });
 }
 
+fn send_generation(generation: Generation, entry: &Entry) -> Result<bool, SendError<Result<InferStreamResponse, InferError>>> {
+    let mut stopped = false;
+
+    if let Some(prefill_tokens) = generation.prefill_tokens {
+        // Send message
+        entry.response_tx
+            .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
+    }
+
+    // Create last Token
+    let token = Token {
+        id: generation.token_id,
+        text: generation.token_text,
+        logprob: generation.token_logprob,
+        special: generation.token_is_special,
+    };
+
+    if let Some(generated_text) = generation.generated_text {
+        // Generation has ended
+        stopped = true;
+        // Send message
+        entry.response_tx
+            .send(Ok(InferStreamResponse::End {
+                token,
+                generated_text,
+                queued: entry.queue_time,
+                start: entry.batch_time.unwrap(),
+            }))?;
+    } else {
+        // Send message
+        entry.response_tx
+            .send(Ok(InferStreamResponse::Token(token)))
+            ?;
+    }
+    Ok(stopped)
+}
+
 #[derive(Debug)]
 pub(crate) enum InferStreamResponse {
     // Optional first message
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 6347b1a5..2ed5cf53 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -3,7 +3,7 @@ import torch
 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
-from typing import Optional, Tuple, List, Type
+from typing import Optional, Tuple, List, Type, Dict
 
 from text_generation_server.models import Model
 from text_generation_server.models.types import (
@@ -22,6 +22,7 @@ tracer = trace.get_tracer(__name__)
 class CausalLMBatch(Batch):
     batch_id: int
     requests: List[generate_pb2.Request]
+    requests_idx_mapping: Dict[int, int]
 
     # Decoder values
     input_ids: torch.Tensor
@@ -42,7 +43,6 @@ class CausalLMBatch(Batch):
     stopping_criterias: List[StoppingCriteria]
 
     # Metadata used for padding
-    size: int
     max_input_length: int
     padding_right_offset: int
 
@@ -53,26 +53,28 @@ class CausalLMBatch(Batch):
         return generate_pb2.Batch(
             id=self.batch_id,
             requests=self.requests,
-            size=self.size,
+            size=len(self),
         )
 
     @classmethod
     def from_pb(
-        cls,
-        pb: generate_pb2.Batch,
-        tokenizer: PreTrainedTokenizerBase,
-        device: torch.device,
+            cls,
+            pb: generate_pb2.Batch,
+            tokenizer: PreTrainedTokenizerBase,
+            device: torch.device,
     ) -> "CausalLMBatch":
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
         offsets = []
         token_offsets = []
+        requests_idx_mapping = {}
 
         # Parse batch
         max_truncation = 0
         padding_right_offset = 0
-        for r in pb.requests:
+        for i, r in enumerate(pb.requests):
+            requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
             offsets.append(None)
             token_offsets.append(None)
@@ -108,26 +110,88 @@ class CausalLMBatch(Batch):
         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
         position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
 
-        all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
+        all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
 
         return cls(
             batch_id=pb.id,
             requests=pb.requests,
+            requests_idx_mapping=requests_idx_mapping,
             input_ids=input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=None,
-            all_input_ids=all_input_ids,
+            all_input_ids=list(all_input_ids),
             input_lengths=input_lengths.tolist(),
             offsets=offsets,
             token_offsets=token_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
-            size=pb.size,
             max_input_length=max_input_length.item(),
             padding_right_offset=padding_right_offset,
         )
 
+    @tracer.start_as_current_span("filter")
+    def filter(self, requests: List[generate_pb2.Request]) -> Optional["CausalLMBatch"]:
+        if len(requests) == 0:
+            raise ValueError("Batch must have at least one request")
+        if len(requests) == len(self):
+            return self
+
+        keep_indices = []
+
+        # New values after filtering
+        requests_idx_mapping = {}
+        input_lengths = []
+        offsets = []
+        token_offsets = []
+        all_input_ids = []
+        max_input_length = 0
+
+        for i, r in enumerate(requests):
+            idx = self.requests_idx_mapping[r.id]
+            keep_indices.append(idx)
+            requests_idx_mapping[r.id] = i
+
+            offsets.append(self.offsets[idx])
+            token_offsets.append(self.token_offsets[idx])
+            all_input_ids.append(self.all_input_ids[idx])
+
+            request_input_length = self.input_lengths[idx]
+            input_lengths.append(request_input_length)
+            max_input_length = max(
+                max_input_length, request_input_length
+            )
+
+        # Replace metadata
+        self.requests_idx_mapping = requests_idx_mapping
+        self.input_lengths = input_lengths
+        self.offsets = offsets
+        self.token_offsets = token_offsets
+        self.all_input_ids = all_input_ids
+        self.max_input_length = max_input_length
+
+        # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
+        self.input_ids = self.input_ids[keep_indices]
+        self.attention_mask = self.attention_mask[keep_indices]
+        self.position_ids = self.position_ids[keep_indices]
+        # Force past to be of dim [self_size, num_heads, ...] for easy indexing
+        self.past_key_values = [
+            [
+                t.view(len(self), -1, *t.shape[-2:])[keep_indices]
+                for t in layer
+            ]
+            for layer in self.past_key_values
+        ]
+        self.requests = [self.requests[i] for i in keep_indices]
+        self.next_token_choosers = [
+            self.next_token_choosers[i] for i in keep_indices
+        ]
+        self.stopping_criterias = [
+            self.stopping_criterias[i] for i in keep_indices
+        ]
+
+        return self
+
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
@@ -136,12 +200,13 @@ class CausalLMBatch(Batch):
         max_input_length = 0
         padding_right_offset = 0
         for batch in batches:
-            total_batch_size += batch.size
+            total_batch_size += len(batch)
             max_input_length = max(max_input_length, batch.max_input_length)
             padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
 
         # Batch attributes
         requests = []
+        requests_idx_mapping = {}
         input_lengths = []
         offsets = []
         token_offsets = []
@@ -167,8 +232,14 @@ class CausalLMBatch(Batch):
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)
 
+            if i == 0:
+                requests_idx_mapping = batch.requests_idx_mapping
+            else:
+                for k, v in batch.requests_idx_mapping.items():
+                    requests_idx_mapping[k] = v + start_index
+
             # Slicing end index for this batch
-            end_index = start_index + batch.size
+            end_index = start_index + len(batch)
 
             # We only concatenate batches that did at least one step
             if batch.past_key_values is None:
@@ -192,17 +263,17 @@ class CausalLMBatch(Batch):
             # and to remove unused allocated space
             left_offset = max_input_length - batch.max_input_length
             batch_left_offset = (
-                batch.attention_mask.shape[1]
-                - batch.max_input_length
-                - batch.padding_right_offset
+                    batch.attention_mask.shape[1]
+                    - batch.max_input_length
+                    - batch.padding_right_offset
             )
             attention_mask[
-                start_index:end_index,
-                left_offset:-padding_right_offset,
+            start_index:end_index,
+            left_offset:-padding_right_offset,
             ] = batch.attention_mask[
                 :,
-                batch_left_offset : -batch.padding_right_offset,
-            ]
+                batch_left_offset: -batch.padding_right_offset,
+                ]
 
             # Create empty tensor
             # position_ids is always of shape [batch_size, 1]
@@ -216,8 +287,8 @@ class CausalLMBatch(Batch):
 
                 # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
                 # BLOOM Keys:   [batch_size * num_heads, head_dim, seq_length]
                 # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
-                past_keys = past_keys.view(batch.size, -1, *past_keys.shape[-2:])
-                past_values = past_values.view(batch.size, -1, *past_values.shape[-2:])
+                past_keys = past_keys.view(len(batch), -1, *past_keys.shape[-2:])
+                past_values = past_values.view(len(batch), -1, *past_values.shape[-2:])
 
                 _, num_heads, padded_sequence_length, head_dim = past_values.shape
@@ -248,28 +319,29 @@ class CausalLMBatch(Batch):
                 # We slice the past keys and values to remove the padding from previous batches
                 if batch.keys_head_dim_last:
                     past_key_values[j][0][
-                        start_index:end_index,
-                        :,
-                        -(batch.max_input_length - 1) :,
-                        :,
-                    ] = past_keys[:, :, -(batch.max_input_length - 1) :, :]
+                    start_index:end_index,
+                    :,
+                    -(batch.max_input_length - 1):,
+                    :,
+                    ] = past_keys[:, :, -(batch.max_input_length - 1):, :]
                 else:
                     past_key_values[j][0][
-                        start_index:end_index,
-                        :,
-                        :,
-                        -(batch.max_input_length - 1) :,
-                    ] = past_keys[:, :, :, -(batch.max_input_length - 1) :]
+                    start_index:end_index,
+                    :,
+                    :,
+                    -(batch.max_input_length - 1):,
+                    ] = past_keys[:, :, :, -(batch.max_input_length - 1):]
 
                 past_key_values[j][1][
-                    start_index:end_index, :, -(batch.max_input_length - 1) :, :
-                ] = past_values[:, :, -(batch.max_input_length - 1) :, :]
+                start_index:end_index, :, -(batch.max_input_length - 1):, :
+                ] = past_values[:, :, -(batch.max_input_length - 1):, :]
 
-            start_index += batch.size
+            start_index += len(batch)
 
         return cls(
             batch_id=batches[0].batch_id,
             requests=requests,
+            requests_idx_mapping=requests_idx_mapping,
             input_ids=input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -280,7 +352,6 @@ class CausalLMBatch(Batch):
             token_offsets=token_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
-            size=total_batch_size,
             max_input_length=max_input_length,
             padding_right_offset=padding_right_offset,
             keys_head_dim_last=batches[0].keys_head_dim_last,
@@ -292,11 +363,11 @@ class CausalLM(Model):
     def __init__(
-        self,
-        model_id: str,
-        revision: Optional[str] = None,
-        quantize: bool = False,
-        decode_buffer: int = 3,
+            self,
+            model_id: str,
+            revision: Optional[str] = None,
+            quantize: bool = False,
+            decode_buffer: int = 3,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
@@ -338,7 +409,7 @@ class CausalLM(Model):
         )
 
     def forward(
-        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+            self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
         # Model Forward
         outputs = self.model.forward(
@@ -352,8 +423,8 @@ class CausalLM(Model):
 
     @tracer.start_as_current_span("generate_token")
     def generate_token(
-        self, batch: CausalLMBatch
-    ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
+            self, batch: CausalLMBatch
+    ) -> Tuple[List[Generation], CausalLMBatch]:
         # slice the attention mask to the correct shape
         attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
 
@@ -364,19 +435,8 @@ class CausalLM(Model):
             batch.past_key_values,
         )
 
-        # List of indices to cache
-        next_batch_keep_indices = []
 
-        # New values for next forward
-        next_batch_input_lengths = []
-        next_batch_offsets = []
-        next_batch_token_offsets = []
         next_batch_input_ids = []
-        next_batch_all_input_ids = []
-
-        # Metadata
-        next_batch_size = 0
-        next_batch_max_input_length = 0
 
         # Results
         generations: List[Generation] = []
@@ -395,14 +455,14 @@ class CausalLM(Model):
 
         # For each member of the batch
         for i, (
-            request,
-            input_length,
-            offset,
-            token_offset,
-            logits,
-            next_token_chooser,
-            stopping_criteria,
-            all_input_ids,
+                request,
+                input_length,
+                offset,
+                token_offset,
+                logits,
+                next_token_chooser,
+                stopping_criteria,
+                all_input_ids,
         ) in enumerate(iterator):
             # Select next token
             next_token_id, logprobs = next_token_chooser(
@@ -429,7 +489,7 @@ class CausalLM(Model):
             if stop:
                 # Decode generated tokens
                 output_text = self.decode(
-                    all_input_ids[-stopping_criteria.current_tokens :, 0]
+                    all_input_ids[-stopping_criteria.current_tokens:, 0]
                 )
                 # Get seed
                 if isinstance(next_token_chooser.choice, Sampling):
@@ -443,16 +503,6 @@ class CausalLM(Model):
             else:
                 # Keep request in the batch
                 generated_text = None
-                next_batch_keep_indices.append(i)
-                next_batch_input_ids.append(next_token_id)
-                next_batch_all_input_ids.append(all_input_ids)
-                next_batch_size += 1
-                next_batch_input_lengths.append(new_input_length)
-                next_batch_offsets.append(offset)
-                next_batch_token_offsets.append(token_offset)
-                next_batch_max_input_length = max(
-                    next_batch_max_input_length, new_input_length
-                )
 
             # Prefill
             if stopping_criteria.current_tokens == 1:
@@ -484,62 +534,25 @@ class CausalLM(Model):
             generations.append(generation)
 
-        # We finished all generations in the batch; there is no next batch
-        if not next_batch_keep_indices:
-            return generations, None
-
-        next_batch_input_ids = torch.cat(next_batch_input_ids, dim=0)
-        # If we finished at least one generation, we need to evict the indices of the generations that finished
-        # from the values of the next batch
-        if len(next_batch_keep_indices) != len(batch):
-            # Apply indices to attention mask, past key values and other items that need to be cached
-            next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
-            next_batch_position_ids = batch.position_ids[next_batch_keep_indices]
-            # Force past to be of dim [batch_size, num_heads, ...] for easy indexing
-            next_batch_past_key_values = [
-                [
-                    t.view(batch.size, -1, *t.shape[-2:])[next_batch_keep_indices]
-                    for t in layer
-                ]
-                for layer in past
-            ]
-            next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
-            next_batch_next_token_choosers = [
-                batch.next_token_choosers[i] for i in next_batch_keep_indices
-            ]
-            next_batch_stopping_criterias = [
-                batch.stopping_criterias[i] for i in next_batch_keep_indices
-            ]
-        else:
-            next_batch_attention_mask = batch.attention_mask
-            next_batch_position_ids = batch.position_ids
-            next_batch_past_key_values = past
-            next_batch_requests = batch.requests
-            next_batch_next_token_choosers = batch.next_token_choosers
-            next_batch_stopping_criterias = batch.stopping_criterias
+            # Update values
+            next_batch_input_ids.append(next_token_id)
+            batch.all_input_ids[i] = all_input_ids
+            batch.input_lengths[i] = new_input_length
+            batch.offsets[i] = offset
+            batch.token_offsets[i] = token_offset
+            batch.max_input_length = max(batch.max_input_length, new_input_length)
 
+        # Decrease right offset
+        batch.padding_right_offset -= 1
+        # Create input_ids tensor
+        batch.input_ids = torch.cat(next_batch_input_ids, dim=0)
         # Update attention_mask as we added a new token to input_ids
-        next_batch_attention_mask[:, -batch.padding_right_offset] = 1
+        batch.attention_mask[:, -batch.padding_right_offset] = 1
 
         # Update position_ids
-        next_batch_position_ids = next_batch_position_ids[:, -1:] + 1
+        batch.position_ids = batch.position_ids[:, -1:] + 1
 
-        next_batch = CausalLMBatch(
-            batch_id=batch.batch_id,
-            requests=next_batch_requests,
-            input_ids=next_batch_input_ids,
-            attention_mask=next_batch_attention_mask,
-            position_ids=next_batch_position_ids,
-            past_key_values=next_batch_past_key_values,
-            all_input_ids=next_batch_all_input_ids,
-            input_lengths=next_batch_input_lengths,
-            offsets=next_batch_offsets,
-            token_offsets=next_batch_token_offsets,
-            next_token_choosers=next_batch_next_token_choosers,
-            stopping_criterias=next_batch_stopping_criterias,
-            size=next_batch_size,
-            max_input_length=next_batch_max_input_length,
-            padding_right_offset=batch.padding_right_offset - 1,
-            keys_head_dim_last=batch.keys_head_dim_last,
-        )
-        return generations, next_batch
+        # Update past key values
+        batch.past_key_values = past
+
+        return generations, batch
diff --git a/server/text_generation_server/models/types.py b/server/text_generation_server/models/types.py
index 93c3b9db..8a5b82f7 100644
--- a/server/text_generation_server/models/types.py
+++ b/server/text_generation_server/models/types.py
@@ -25,6 +25,10 @@ class Batch(ABC):
     ) -> "Batch":
         raise NotImplementedError
 
+    @abstractmethod
+    def filter(self, requests: List[generate_pb2.Request]) -> "Batch":
+        raise NotImplementedError
+
     @classmethod
     @abstractmethod
     def concatenate(cls, batches: List["Batch"]) -> "Batch":
         raise NotImplementedError
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 3e3789bf..3caee803 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -60,7 +60,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             batch = self.cache.pop(batch_pb.id)
             if batch is None:
                 raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
-            batches.append(batch)
+            batch = batch.filter(batch_pb.requests)
+            if batch is not None:
+                batches.append(batch)
+
+        if len(batches) == 0:
+            raise ValueError("All batches are empty")
 
         if len(batches) > 1:
             batch = self.model.batch_type.concatenate(batches)
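# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch above): a minimal, standalone
# Python example of the index-based filtering idea behind the new
# `CausalLMBatch.filter`. A `requests_idx_mapping` dict maps request ids to
# row positions, so finished requests can be dropped by gathering only the
# surviving rows and re-numbering the mapping, instead of rebuilding the
# whole batch. `ToyBatch`, the request ids and the tensor shapes below are
# made up for illustration; this is not the TGI implementation.
from dataclasses import dataclass
from typing import Dict, List

import torch


@dataclass
class ToyBatch:
    requests: List[int]                   # request ids (stand-ins for generate_pb2.Request)
    requests_idx_mapping: Dict[int, int]  # request id -> row index
    input_ids: torch.Tensor               # [batch_size, 1]
    attention_mask: torch.Tensor          # [batch_size, seq_len]

    def __len__(self) -> int:
        return len(self.requests)

    def filter(self, kept_ids: List[int]) -> "ToyBatch":
        if len(kept_ids) == 0:
            raise ValueError("Batch must have at least one request")
        if len(kept_ids) == len(self):
            return self

        # Row indices of the requests that are still running
        keep_indices = [self.requests_idx_mapping[r] for r in kept_ids]
        # Surviving requests are re-numbered 0..n-1 in the filtered batch
        self.requests_idx_mapping = {r: i for i, r in enumerate(kept_ids)}
        self.requests = kept_ids
        # Advanced indexing keeps only the surviving rows of each tensor
        self.input_ids = self.input_ids[keep_indices]
        self.attention_mask = self.attention_mask[keep_indices]
        return self


if __name__ == "__main__":
    batch = ToyBatch(
        requests=[10, 11, 12],
        requests_idx_mapping={10: 0, 11: 1, 12: 2},
        input_ids=torch.tensor([[1], [2], [3]]),
        attention_mask=torch.ones(3, 4, dtype=torch.int64),
    )
    # Request 11 finished, so only 10 and 12 are carried into the next step
    batch = batch.filter([10, 12])
    assert len(batch) == 2
    assert batch.requests_idx_mapping == {10: 0, 12: 1}
    assert batch.input_ids.tolist() == [[1], [3]]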