From 1cc86930a682f1a1f63434e798220bf2ebe7ca21 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Wed, 5 Jun 2024 17:01:06 +0200
Subject: [PATCH] wip

---
 proto/v3/generate.proto                        | 33 +++-----
 router/client/src/v3/client.rs                 |  4 +-
 router/client/src/v3/mod.rs                    |  2 +-
 router/client/src/v3/sharded_client.rs         |  6 +-
 router/src/infer/mod.rs                        |  3 +
 router/src/infer/v3/block_allocator.rs         |  8 ++
 router/src/infer/v3/queue.rs                   |  2 +-
 router/src/infer/v3/scheduler.rs               | 82 +++++++++++--------
 router/src/lib.rs                              |  3 -
 router/src/server.rs                           |  1 +
 .../models/causal_lm.py                        |  8 +-
 .../models/flash_causal_lm.py                  | 67 +++++++--------
 .../models/idefics_causal_lm.py                |  8 +-
 server/text_generation_server/models/mamba.py  |  8 +-
 .../models/seq2seq_lm.py                       |  6 +-
 server/text_generation_server/models/types.py  |  2 +-
 .../models/vlm_causal_lm.py                    |  6 +-
 server/text_generation_server/server.py        |  2 +-
 18 files changed, 138 insertions(+), 113 deletions(-)

diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto
index d57fbbad..192cd111 100644
--- a/proto/v3/generate.proto
+++ b/proto/v3/generate.proto
@@ -17,8 +17,6 @@ service TextGenerationService {
     rpc Prefill (PrefillRequest) returns (PrefillResponse);
     /// Decode token for a list of prefilled batches
     rpc Decode (DecodeRequest) returns (DecodeResponse);
-    /// Update batch
-    rpc Update(UpdateRequest) returns (UpdateResponse);
     /// Health check
     rpc Health (HealthRequest) returns (HealthResponse);
 }
@@ -204,11 +202,20 @@ message Generation {
     uint32 current_length = 6;
 }
 
+message UpdatedRequest {
+    /// Request ID
+    uint64 id = 1;
+    /// Paged attention blocks
+    repeated uint32 blocks = 2;
+    /// Paged attention slots
+    repeated uint32 slots = 3;
+}
+
 message FilterBatchRequest {
     /// Batch ID
     uint64 batch_id = 1;
     /// Requests to keep
-    repeated uint64 request_ids = 2;
+    repeated UpdatedRequest updated_requests = 2;
 }
 
 message FilterBatchResponse {
@@ -255,26 +262,6 @@ message DecodeResponse {
     optional uint64 concat_ns = 6;
 }
 
-message ExtendedRequest {
-    /// Request ID
-    uint64 request_id = 1;
-    /// Paged attention blocks to add
-    repeated uint32 blocks = 2;
-    /// Paged attention slots to add
-    repeated uint32 slots = 3;
-}
-
-message UpdateRequest {
-    /// Batch ID
-    uint64 batch_id = 1;
-    /// Requests to update
-    repeated ExtendedRequest extend_requests = 2;
-    /// Requests to terminate
-    repeated uint64 terminated_request_ids = 3;
-}
-
-message UpdateResponse {}
-
 message WarmupRequest {
     /// Batch to warmup on
     Batch batch = 1;
diff --git a/router/client/src/v3/client.rs b/router/client/src/v3/client.rs
index 9a3892fb..8cefd313 100644
--- a/router/client/src/v3/client.rs
+++ b/router/client/src/v3/client.rs
@@ -90,11 +90,11 @@ impl Client {
     pub async fn filter_batch(
         &mut self,
         batch_id: u64,
-        request_ids: Vec<u64>,
+        updated_requests: Vec<UpdatedRequest>,
     ) -> Result<Option<CachedBatch>> {
         let request = tonic::Request::new(FilterBatchRequest {
             batch_id,
-            request_ids,
+            updated_requests,
         })
         .inject_context();
         let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
diff --git a/router/client/src/v3/mod.rs b/router/client/src/v3/mod.rs
index 4a1296a2..df2bb380 100644
--- a/router/client/src/v3/mod.rs
+++ b/router/client/src/v3/mod.rs
@@ -8,6 +8,6 @@ pub use client::Client;
 pub use pb::generate::v3::{
     input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
     HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
-    StoppingCriteriaParameters, Tokens,
+    StoppingCriteriaParameters, Tokens, UpdatedRequest,
 };
 pub use sharded_client::ShardedClient;
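Note on the wire format above: FilterBatchRequest no longer carries bare request IDs; each kept request now ships its own block and slot assignment. As a rough Python-side illustration (not part of the patch, assuming the v3 bindings are regenerated and importable as text_generation_server.pb.generate_pb2 like the rest of the server code; all numbers are arbitrary):

# Illustrative only: building the reshaped FilterBatchRequest with the regenerated bindings.
from text_generation_server.pb import generate_pb2

# Keep requests 0 and 2 of batch 5, re-stating the pages each one owns.
request = generate_pb2.FilterBatchRequest(
    batch_id=5,
    updated_requests=[
        generate_pb2.UpdatedRequest(id=0, blocks=[0, 1], slots=list(range(0, 32))),
        generate_pb2.UpdatedRequest(id=2, blocks=[7], slots=list(range(112, 128))),
    ],
)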
diff --git a/router/client/src/v3/sharded_client.rs b/router/client/src/v3/sharded_client.rs
index 94002f55..a066176c 100644
--- a/router/client/src/v3/sharded_client.rs
+++ b/router/client/src/v3/sharded_client.rs
@@ -10,7 +10,7 @@ use tracing::instrument;
 use v3::client::{DecodeTimings, PrefillTimings};
 use v3::{
     Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
-    NextTokenChooserParameters, Request, StoppingCriteriaParameters,
+    NextTokenChooserParameters, Request, StoppingCriteriaParameters, UpdatedRequest,
 };
 
 #[derive(Debug, Clone)]
@@ -84,12 +84,12 @@ impl ShardedClient {
     pub async fn filter_batch(
         &mut self,
         batch_id: u64,
-        request_ids: Vec<u64>,
+        updated_requests: Vec<UpdatedRequest>,
     ) -> Result<Option<CachedBatch>> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
-            .map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))
+            .map(|client| Box::pin(client.filter_batch(batch_id, updated_requests.clone())))
             .collect();
         // all shards return the same message
         join_all(futures).await.pop().unwrap()
diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs
index 20630c1b..3b61e466 100644
--- a/router/src/infer/mod.rs
+++ b/router/src/infer/mod.rs
@@ -506,6 +506,8 @@ pub enum InferError {
     TemplateError(#[from] minijinja::Error),
     #[error("Tool error: {0}")]
     ToolError(String),
+    #[error("Request could not be re-allocated: out of pages")]
+    OutOfPages,
 }
 
 impl InferError {
@@ -517,6 +519,7 @@ impl InferError {
             InferError::IncompleteGeneration => "incomplete_generation",
             InferError::TemplateError(_) => "template_error",
             InferError::ToolError(_) => "tool_error",
+            InferError::OutOfPages => "out_of_pages",
         }
     }
 }
diff --git a/router/src/infer/v3/block_allocator.rs b/router/src/infer/v3/block_allocator.rs
index 7467fd85..811efb26 100644
--- a/router/src/infer/v3/block_allocator.rs
+++ b/router/src/infer/v3/block_allocator.rs
@@ -8,6 +8,12 @@ pub(crate) struct BlockAllocation {
     block_allocator: BlockAllocator,
 }
 
+impl BlockAllocation {
+    pub(crate) fn len(&self) -> usize {
+        self.slots.len()
+    }
+}
+
 impl Drop for BlockAllocation {
     fn drop(&mut self) {
         self.block_allocator.free(self.blocks.clone())
@@ -83,6 +89,8 @@ async fn block_allocator_task(
                 tokens,
                 response_sender,
             } => {
+                // let tokens = 16;
+
                 // Apply window size
                 let (required_blocks, repeats) = {
                     let (tokens, repeats) = match window_size {
diff --git a/router/src/infer/v3/queue.rs b/router/src/infer/v3/queue.rs
index 15226794..1ac06ae9 100644
--- a/router/src/infer/v3/queue.rs
+++ b/router/src/infer/v3/queue.rs
@@ -34,7 +34,7 @@ pub(crate) struct Entry {
     /// Block Allocation
    pub block_allocation: Option<BlockAllocation>,
     /// Current length (in tokens) of the request (prompt tokens + generated_tokens)
-    pub current_length: u32
+    pub current_length: u32,
 }
 
 /// Request Queue
diff --git a/router/src/infer/v3/scheduler.rs b/router/src/infer/v3/scheduler.rs
index bf52e69f..faa899ec 100644
--- a/router/src/infer/v3/scheduler.rs
+++ b/router/src/infer/v3/scheduler.rs
@@ -10,7 +10,7 @@ use std::sync::{
     atomic::{AtomicBool, Ordering},
     Arc,
 };
-use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient};
+use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient, UpdatedRequest};
 use text_generation_client::ClientError;
 use tokio::sync::mpsc::error::SendError;
 use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit};
@@ -288,7 +288,7 @@ async fn decode(
 
     // Send generated tokens and filter stopped entries
     filter_send_generations(generations, entries);
-    filter_update_allocations(client, entries).await;
+    filter_update_allocations(entries).await;
 
     // Filter next batch and remove requests that were stopped
     let next_batch = filter_batch(client, next_batch, entries).await;
@@ -323,7 +323,7 @@ async fn filter_batch(
     next_batch: Option<CachedBatch>,
     entries: &IntMap<u64, Entry>,
 ) -> Option<CachedBatch> {
-    let mut batch = next_batch?;
+    let batch = next_batch?;
 
     // No need to filter
     if batch.size as usize == entries.len() {
@@ -331,11 +331,7 @@ async fn filter_batch(
     }
 
     let id = batch.id;
-
-    // Retain only requests that are still in entries
-    batch.request_ids.retain(|id| entries.contains_key(id));
-
-    if batch.request_ids.is_empty() {
+    if entries.is_empty() {
         // All requests have been filtered out
         // Next batch is now empty
         // Clear it from the Python shards cache
@@ -344,8 +340,24 @@ async fn filter_batch(
         None
     } else {
         // Filter Python shard cache
+        let updated_requests = entries
+            .iter()
+            .map(|(request_id, entry)| {
+                let (blocks, slots) = entry
+                    .block_allocation
+                    .as_ref()
+                    .map(|alloc| (alloc.blocks.clone(), alloc.slots.clone()))
+                    .unwrap_or((Vec::new(), Vec::new()));
+                UpdatedRequest {
+                    id: *request_id,
+                    blocks,
+                    slots,
+                }
+            })
+            .collect();
+
         // We unwrap here as we need to panic since we cannot recover if this method fails
-        client.filter_batch(id, batch.request_ids).await.unwrap()
+        client.filter_batch(id, updated_requests).await.unwrap()
     }
 }
@@ -379,32 +391,36 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
     });
 }
 
-async fn filter_update_allocations(client: &mut ShardedClient, entries: &mut IntMap<u64, Entry>) {
-    // let mut extend_entries = Vec::with_capacity(entries.len());
-    // let mut finish_entries = Vec::with_capacity(entries.len());
+async fn filter_update_allocations(entries: &mut IntMap<u64, Entry>) {
+    entries.retain(|request_id, entry| {
+        if entry.block_allocation.is_none() {
+            return true;
+        }
 
-    // for (request_id, entry) in entries.into_iter() {
-    //     tracing::info!("Allocation {}; Current Length: {}", entry.block_allocation.as_ref().unwrap().allocated_tokens, entry.current_length);
-    //
-    //     if let Some(block_allocation) = &mut entry.block_allocation {
-    //         tracing::info!("Allocation {:?}", block_allocation);
-    //
-    //         if entry.current_length > block_allocation.allocated_tokens {
-    //             // We need to add new blocks to this entry
-    //             let remaining_tokens = block_allocation.total_tokens - entry.current_length;
-    //             match block_allocation.extend(remaining_tokens).await {
-    //                 true => {
-    //
-    //                 },
-    //                 false => {
-    //
-    //                 }
-    //             }
-    //         }
-    //     }
-    // }
+        // We can unwrap since we already validated above that block_allocation is not None
+        let mut block_allocation = entry.block_allocation.as_ref().unwrap();
+
+        // Nothing to update
+        if entry.current_length <= block_allocation.len() as u32 {
+            return true;
+        }
+
+        // Create and enter a span to link this function back to the entry
+        let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
+        let err = InferError::OutOfPages;
+        metrics::increment_counter!("tgi_request_failure", "err" => "out_of_pages");
+        tracing::error!("{err}");
+
+        // unwrap_or is valid here as we don't care if the receiver is gone.
+        entry
+            .response_tx
+            .send(Err(err))
+            .unwrap_or(());
+
+        false
+    });
 }
 
 /// Send responses through the `entry` response channel
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 52c5aa46..b6902c49 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -1085,8 +1085,6 @@ pub(crate) enum FinishReason {
     EndOfSequenceToken,
     #[schema(rename = "stop_sequence")]
     StopSequence,
-    #[schema(rename = "out_of_pages")]
-    OutOfPages
 }
 
 impl std::fmt::Display for FinishReason {
@@ -1095,7 +1093,6 @@ impl std::fmt::Display for FinishReason {
             FinishReason::Length => write!(f, "length"),
             FinishReason::EndOfSequenceToken => write!(f, "eos_token"),
             FinishReason::StopSequence => write!(f, "stop_sequence"),
-            FinishReason::OutOfPages => write!(f, "out_of_pages"),
         }
     }
 }
diff --git a/router/src/server.rs b/router/src/server.rs
index 30479b0e..9df33739 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -1859,6 +1859,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
             InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
             InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
             InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY,
+            InferError::OutOfPages => StatusCode::TOO_MANY_REQUESTS,
         };
 
         (
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 2fe0f56e..50a25a50 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -158,7 +158,11 @@ class CausalLMBatch(Batch):
         )
 
     @tracer.start_as_current_span("filter")
-    def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:
+    def filter(
+        self, updated_requests: List[generate_pb2.UpdatedRequest]
+    ) -> Optional["CausalLMBatch"]:
+        request_ids = [r.id for r in updated_requests]
+
         if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
         if len(request_ids) == len(self):
@@ -746,7 +750,7 @@ class CausalLM(Model):
                 ),
                 generated_text,
                 top_tokens,
-                new_input_length
+                new_input_length,
             )
 
             generations.append(generation)
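The FlashCausalLMBatch changes below drop the per-request Python lists (block_tables, slots) and keep only packed tensors, so filter() has to rebuild those tensors from the blocks and slots the router sends in each UpdatedRequest. A minimal, self-contained sketch of that rebuild (a stand-in dataclass replaces the generated protobuf class; values are arbitrary):

# Standalone sketch of the rebuild the new filter() path performs.
from dataclasses import dataclass
from typing import List

import torch


@dataclass
class KeptRequest:
    # Stand-in for generate_pb2.UpdatedRequest
    id: int
    blocks: List[int]
    slots: List[int]


def rebuild_paged_tensors(updated_requests: List[KeptRequest], device: str = "cpu"):
    block_tables = [r.blocks for r in updated_requests]
    max_blocks = max(len(b) for b in block_tables)

    # Build on CPU, then move, mirroring the patch
    block_tables_tensor = torch.zeros(
        (len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
    )
    for i, request_blocks in enumerate(block_tables):
        block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)

    flat_slots = [s for r in updated_requests for s in r.slots]
    slots = torch.tensor(flat_slots, dtype=torch.int64, device=device)

    return block_tables_tensor.to(device), slots


block_tables_tensor, slots = rebuild_paged_tensors(
    [
        KeptRequest(id=0, blocks=[0, 1], slots=list(range(32))),
        KeptRequest(id=2, blocks=[7], slots=list(range(112, 128))),
    ]
)
print(block_tables_tensor.shape, slots.shape)  # torch.Size([2, 2]) torch.Size([48])

The practical effect is that the Python shards no longer need to carry these lists between calls; they can be reconstructed on each filter from what the router already tracks.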
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index da5fa9db..c4c1cf9a 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -82,14 +82,10 @@ class FlashCausalLMBatch(Batch):
     # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
     slot_indices: torch.Tensor
 
-    # list of length b of list of length s_i // block_size
-    block_tables: List[List[int]]
     # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
     block_tables_tensor: torch.Tensor
-    # list of length b of list of length s_i
-    slots: List[List[int]]
     # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
-    slots_tensor: torch.Tensor
+    slots: torch.Tensor
 
     max_seqlen: int
@@ -183,7 +179,6 @@ class FlashCausalLMBatch(Batch):
         max_blocks = 0
 
         block_tables = []
-        slots = []
         flat_slots = []
 
         # Parse batch
@@ -253,7 +248,6 @@ class FlashCausalLMBatch(Batch):
                 len(flat_slots) + input_length,
                 dtype=torch.int64,
             )
-            slots.append(request_slots)
             flat_slots.extend(request_slots)
             slot_indices.append(request_slot_indices)
@@ -353,7 +347,7 @@ class FlashCausalLMBatch(Batch):
             top_n_tokens, device=device, dtype=torch.int64
         )
 
-        slots_tensor = torch.tensor(flat_slots, dtype=torch.int64, device=device)
+        slots = torch.tensor(flat_slots, dtype=torch.int64, device=device)
         block_tables_tensor = torch.zeros(
             (len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
         )
@@ -370,10 +364,8 @@ class FlashCausalLMBatch(Batch):
             cu_seqlen_prefill=cu_seqlen_prefill,
             prefill_cache_indices=prefill_cache_indices,
             slot_indices=slot_indices,
-            block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
             slots=slots,
-            slots_tensor=slots_tensor,
             max_seqlen=max_seqlen,
             prefill_head_indices=prefill_head_indices,
             prefill_next_token_indices=prefill_next_token_indices,
@@ -405,11 +397,13 @@ class FlashCausalLMBatch(Batch):
         return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
 
     @tracer.start_as_current_span("filter")
-    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
-        if len(request_ids) == 0:
+    def filter(
+        self, updated_requests: List[generate_pb2.UpdatedRequest]
+    ) -> Optional["FlashCausalLMBatch"]:
+        if len(updated_requests) == 0:
             raise ValueError("Batch must have at least one request")
         # We assume that if len(requests) == len(self) then the requests are the same
-        if len(request_ids) == len(self):
+        if len(updated_requests) == len(self):
             return self
 
         device = self.input_ids.device
@@ -425,7 +419,6 @@ class FlashCausalLMBatch(Batch):
 
         requests = []
         block_tables = []
-        slots = []
         flat_slots = []
 
         all_input_ids = []
@@ -439,7 +432,9 @@ class FlashCausalLMBatch(Batch):
         num_blocks = 0
         max_blocks = 0
 
-        for i, request_id in enumerate(request_ids):
+        for i, request in enumerate(updated_requests):
+            request_id = request.id
+
             idx = self.requests_idx_mapping[request_id]
             indices.append(idx)
             requests_idx_mapping[request_id] = i
@@ -461,13 +456,12 @@ class FlashCausalLMBatch(Batch):
 
             top_n_tokens.append(self.top_n_tokens[idx])
 
-            request_block_table = self.block_tables[idx]
+            request_block_table = request.blocks
             num_blocks += len(request_block_table)
             block_tables.append(request_block_table)
 
             # List of slots allocated for this request
-            request_slots = self.slots[idx]
-            slots.append(request_slots)
+            request_slots = request.slots
 
             # Index
             slot_indices.append(len(flat_slots) + request_input_length - 1)
@@ -479,7 +473,6 @@ class FlashCausalLMBatch(Batch):
         input_ids = self.input_ids[indices]
         position_ids = self.position_ids[indices]
         all_input_ids_tensor = self.all_input_ids_tensor[indices]
-        block_tables_tensor = self.block_tables_tensor[indices]
         input_lengths_tensor = self.input_lengths_tensor[indices]
         next_token_chooser = self.next_token_chooser.filter(indices)
         top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
@@ -487,10 +480,20 @@ class FlashCausalLMBatch(Batch):
             self.speculative_ids[indices] if self.speculative_ids is not None else None
         )
 
+        # Create block_tables_tensor on CPU
+        block_tables_tensor = torch.zeros(
+            (len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
+        )
+        for i, request_blocks in enumerate(block_tables):
+            block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
+
         # Allocate on GPU
-        slots_tensor = torch.tensor(flat_slots, dtype=torch.int64, device=device)
+        slots = torch.tensor(flat_slots, dtype=torch.int64, device=device)
         slot_indices = torch.tensor(slot_indices, dtype=torch.int64, device=device)
 
+        # Move to GPU
+        block_tables_tensor = block_tables_tensor.to(device)
+
         return type(self)(
             batch_id=self.batch_id,
             requests=requests,
@@ -500,10 +503,8 @@ class FlashCausalLMBatch(Batch):
             cu_seqlen_prefill=None,
             prefill_cache_indices=None,
             slot_indices=slot_indices,
-            block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
             slots=slots,
-            slots_tensor=slots_tensor,
             max_seqlen=max_seqlen,
             prefill_head_indices=None,
             prefill_next_token_indices=None,
@@ -538,7 +539,7 @@ class FlashCausalLMBatch(Batch):
         max_seqlen = 0
         for b in batches:
             total_batch_size += len(b)
-            total_slots += len(b.slots_tensor)
+            total_slots += len(b.slots)
             num_blocks += b.num_blocks
             speculative_length = (
                 b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
@@ -561,7 +562,7 @@ class FlashCausalLMBatch(Batch):
         input_ids = batches[0].input_ids.new_empty(total_batch_size)
         position_ids = batches[0].position_ids.new_empty(total_batch_size)
-        slots_tensor = batches[0].slots_tensor.new_empty(total_slots)
+        slots = batches[0].slots.new_empty(total_slots)
         slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
         input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
             total_batch_size
         )
@@ -576,8 +577,6 @@ class FlashCausalLMBatch(Batch):
             total_batch_size,
         )
 
-        slots = []
-        block_tables = []
         all_input_ids = []
 
         input_lengths = []
@@ -606,7 +605,7 @@ class FlashCausalLMBatch(Batch):
             start_index = cumulative_batch_size
             end_index = cumulative_batch_size + len(batch)
             slots_start_index = cumulative_slots
-            slots_end_index = cumulative_slots + len(batch.slots_tensor)
+            slots_end_index = cumulative_slots + len(batch.slots)
 
             # Copy tensors (GPU)
             input_ids[start_index:end_index] = batch.input_ids
             position_ids[start_index:end_index] = batch.position_ids
             slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
             input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
             top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
-            slots_tensor[slots_start_index:slots_end_index] = batch.slots_tensor
+            slots[slots_start_index:slots_end_index] = batch.slots
 
             all_input_ids_tensor[
                 start_index:end_index, : batch.all_input_ids_tensor.shape[1]
             ] = batch.all_input_ids_tensor
 
             block_tables_tensor[
                 start_index:end_index, : batch.block_tables_tensor.shape[1]
             ] = batch.block_tables_tensor[:, :max_blocks]
 
-            slots.extend(batch.slots)
-            block_tables.extend(batch.block_tables)
             all_input_ids.extend(batch.all_input_ids)
 
             input_lengths.extend(batch.input_lengths)
@@ -640,7 +637,7 @@ class FlashCausalLMBatch(Batch):
 
             # Update
             cumulative_batch_size += len(batch)
-            cumulative_slots += len(batch.slots_tensor)
+            cumulative_slots += len(batch.slots)
 
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             next_token_chooser_parameters,
@@ -665,10 +662,8 @@ class FlashCausalLMBatch(Batch):
             cu_seqlen_prefill=None,
             prefill_cache_indices=None,
             slot_indices=slot_indices,
-            block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
             slots=slots,
-            slots_tensor=slots_tensor,
             max_seqlen=max_seqlen,
             prefill_head_indices=None,
             prefill_next_token_indices=None,
@@ -969,7 +964,7 @@ class FlashCausalLM(Model):
         cu_seqlen_prefill = batch.cu_seqlen_prefill
         kv_cache = self.kv_cache
         block_tables = batch.block_tables_tensor
-        slots = batch.slots_tensor[batch.slot_indices]
+        slots = batch.slots[batch.slot_indices]
         input_lengths = batch.input_lengths_tensor
         max_s = batch.max_seqlen
         lm_head_indices = batch.prefill_head_indices
@@ -1008,7 +1003,7 @@ class FlashCausalLM(Model):
         cu_seqlen_prefill = batch.cu_seqlen_prefill
         kv_cache = self.kv_cache
         block_tables = batch.block_tables_tensor
-        slots = batch.slots_tensor[batch.slot_indices]
+        slots = batch.slots[batch.slot_indices]
         input_lengths = batch.input_lengths_tensor
         max_s = batch.max_seqlen
         lm_head_indices = batch.prefill_head_indices
@@ -1350,7 +1345,7 @@ class FlashCausalLM(Model):
                     ),
                     generated_text,
                     top_tokens,
-                    input_length + n_accepted_ids
+                    input_length + n_accepted_ids,
                 )
 
                 generations.append(generation)
diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py
index 44b21899..fd70ae5d 100644
--- a/server/text_generation_server/models/idefics_causal_lm.py
+++ b/server/text_generation_server/models/idefics_causal_lm.py
@@ -214,7 +214,11 @@ class IdeficsCausalLMBatch(Batch):
         )
 
     @tracer.start_as_current_span("filter")
-    def filter(self, request_ids: List[int]) -> Optional["IdeficsCausalLMBatch"]:
+    def filter(
+        self, updated_requests: List[generate_pb2.UpdatedRequest]
+    ) -> Optional["IdeficsCausalLMBatch"]:
+        request_ids = [r.id for r in updated_requests]
+
         # It deletes requests from the batch. For instance when client lost connection
         if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
@@ -829,7 +833,7 @@ class IdeficsCausalLM(Model):
                 ),
                 generated_text,
                 top_tokens,
-                new_input_length
+                new_input_length,
             )
 
             generations.append(generation)
diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py
index 8182eb46..c8066aec 100644
--- a/server/text_generation_server/models/mamba.py
+++ b/server/text_generation_server/models/mamba.py
@@ -195,7 +195,11 @@ class MambaBatch(Batch):
             max_tokens=max_tokens,
         )
 
-    def filter(self, request_ids: List[int]) -> Optional["MambaBatch"]:
+    def filter(
+        self, updated_requests: List[generate_pb2.UpdatedRequest]
+    ) -> Optional["MambaBatch"]:
+        request_ids = [r.id for r in updated_requests]
+
         if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
         if len(request_ids) == len(self):
@@ -775,7 +779,7 @@ class Mamba(Model):
                 ),
                 generated_text,
                 top_tokens,
-                new_input_length
+                new_input_length,
             )
 
             generations.append(generation)
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 74ea2dab..1e4f7c2e 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -166,7 +166,11 @@ class Seq2SeqLMBatch(Batch):
         )
 
     @tracer.start_as_current_span("filter")
-    def filter(self, request_ids: List[int]) -> Optional["Seq2SeqLMBatch"]:
+    def filter(
+        self, updated_requests: List[generate_pb2.UpdatedRequest]
+    ) -> Optional["Seq2SeqLMBatch"]:
+        request_ids = [r.id for r in updated_requests]
+
         if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
         if len(request_ids) == len(self):
diff --git a/server/text_generation_server/models/types.py b/server/text_generation_server/models/types.py
index 1c7a157a..50c14862 100644
--- a/server/text_generation_server/models/types.py
+++ b/server/text_generation_server/models/types.py
@@ -28,7 +28,7 @@ class Batch(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def filter(self, request_ids: List[int]) -> "Batch":
+    def filter(self, updated_requests: List[generate_pb2.UpdatedRequest]) -> "Batch":
         raise NotImplementedError
 
     @classmethod
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index b1ccd140..bc51e732 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -122,8 +122,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         return batch
 
     @tracer.start_as_current_span("filter")
-    def filter(self, request_ids: List[int]):
-        batch = super().filter(request_ids)
+    def filter(
+        self, updated_requests: List[generate_pb2.UpdatedRequest]
+    ) -> Optional["VlmCausalLMBatch"]:
+        batch = super().filter(updated_requests)
         batch.pixel_values = None
         batch.pixel_attention_mask = None
         batch.image_sizes = None
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 569b6925..a66c19a0 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -83,7 +83,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         batch = self.cache.pop(request.batch_id)
         if batch is None:
             raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
-        filtered_batch = batch.filter(request.request_ids)
+        filtered_batch = batch.filter(request.updated_requests)
         self.cache.set(filtered_batch)
 
         return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
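End-to-end, the slots_tensor to slots rename leaves the forward pass indexing a single flat tensor: per the batch comments above, slot_indices holds one entry per sequence during decode, so batch.slots[batch.slot_indices] yields the KV-cache slot each new token writes to. A toy illustration with made-up values, not taken from the patch:

import torch

# Flat slots for two sequences: slots 0..3 belong to the first, 112..114 to the second.
slots = torch.tensor([0, 1, 2, 3, 112, 113, 114], dtype=torch.int64)
# One index per sequence at decode time: the current position inside the flat tensor.
slot_indices = torch.tensor([2, 6], dtype=torch.int64)

print(slots[slot_indices])  # tensor([  2, 114])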