This commit is contained in:
OlivierDehaene 2024-06-05 17:01:06 +02:00
parent 18e77a5cc7
commit 1cc86930a6
18 changed files with 138 additions and 113 deletions

View File

@ -17,8 +17,6 @@ service TextGenerationService {
rpc Prefill (PrefillRequest) returns (PrefillResponse); rpc Prefill (PrefillRequest) returns (PrefillResponse);
/// Decode token for a list of prefilled batches /// Decode token for a list of prefilled batches
rpc Decode (DecodeRequest) returns (DecodeResponse); rpc Decode (DecodeRequest) returns (DecodeResponse);
/// Update batch
rpc Update(UpdateRequest) returns (UpdateResponse);
/// Health check /// Health check
rpc Health (HealthRequest) returns (HealthResponse); rpc Health (HealthRequest) returns (HealthResponse);
} }
@ -204,11 +202,20 @@ message Generation {
uint32 current_length = 6; uint32 current_length = 6;
} }
/// A request that stays in a batch after filtering, together with its paged attention allocation
message UpdatedRequest {
/// ID of the request to keep
uint64 id = 1;
/// Paged attention blocks allocated to this request
repeated uint32 blocks = 2;
/// Paged attention slots allocated to this request
repeated uint32 slots = 3;
}
message FilterBatchRequest { message FilterBatchRequest {
/// Batch ID /// Batch ID
uint64 batch_id = 1; uint64 batch_id = 1;
/// Requests to keep /// Requests to keep
repeated uint64 request_ids = 2; repeated UpdatedRequest updated_requests = 2;
} }
message FilterBatchResponse { message FilterBatchResponse {
@ -255,26 +262,6 @@ message DecodeResponse {
optional uint64 concat_ns = 6; optional uint64 concat_ns = 6;
} }
message ExtendedRequest {
/// Request ID
uint64 request_id = 1;
/// Paged attention blocks to add
repeated uint32 blocks = 2;
/// Paged attention slots to add
repeated uint32 slots = 3;
}
message UpdateRequest {
/// Batch ID
uint64 batch_id = 1;
/// Requests to update
repeated ExtendedRequest extend_requests = 2;
/// Requests to terminate
repeated uint64 terminated_request_ids = 3;
}
message UpdateResponse {}
message WarmupRequest { message WarmupRequest {
/// Batch to warmup on /// Batch to warmup on
Batch batch = 1; Batch batch = 1;

View File

@ -90,11 +90,11 @@ impl Client {
pub async fn filter_batch( pub async fn filter_batch(
&mut self, &mut self,
batch_id: u64, batch_id: u64,
request_ids: Vec<u64>, updated_requests: Vec<UpdatedRequest>,
) -> Result<Option<CachedBatch>> { ) -> Result<Option<CachedBatch>> {
let request = tonic::Request::new(FilterBatchRequest { let request = tonic::Request::new(FilterBatchRequest {
batch_id, batch_id,
request_ids, updated_requests,
}) })
.inject_context(); .inject_context();
let filtered_batch = self.stub.filter_batch(request).await?.into_inner(); let filtered_batch = self.stub.filter_batch(request).await?.into_inner();

View File

@ -8,6 +8,6 @@ pub use client::Client;
pub use pb::generate::v3::{ pub use pb::generate::v3::{
input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request, HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
StoppingCriteriaParameters, Tokens, StoppingCriteriaParameters, Tokens, UpdatedRequest,
}; };
pub use sharded_client::ShardedClient; pub use sharded_client::ShardedClient;

View File

@ -10,7 +10,7 @@ use tracing::instrument;
use v3::client::{DecodeTimings, PrefillTimings}; use v3::client::{DecodeTimings, PrefillTimings};
use v3::{ use v3::{
Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse, Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
NextTokenChooserParameters, Request, StoppingCriteriaParameters, NextTokenChooserParameters, Request, StoppingCriteriaParameters, UpdatedRequest,
}; };
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -84,12 +84,12 @@ impl ShardedClient {
pub async fn filter_batch( pub async fn filter_batch(
&mut self, &mut self,
batch_id: u64, batch_id: u64,
request_ids: Vec<u64>, updated_requests: Vec<UpdatedRequest>,
) -> Result<Option<CachedBatch>> { ) -> Result<Option<CachedBatch>> {
let futures: Vec<_> = self let futures: Vec<_> = self
.clients .clients
.iter_mut() .iter_mut()
.map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone()))) .map(|client| Box::pin(client.filter_batch(batch_id, updated_requests.clone())))
.collect(); .collect();
// all shards return the same message // all shards return the same message
join_all(futures).await.pop().unwrap() join_all(futures).await.pop().unwrap()

View File

@ -506,6 +506,8 @@ pub enum InferError {
TemplateError(#[from] minijinja::Error), TemplateError(#[from] minijinja::Error),
#[error("Tool error: {0}")] #[error("Tool error: {0}")]
ToolError(String), ToolError(String),
#[error("Request could not be re-allocated: out of pages")]
OutOfPages,
} }
impl InferError { impl InferError {
@ -517,6 +519,7 @@ impl InferError {
InferError::IncompleteGeneration => "incomplete_generation", InferError::IncompleteGeneration => "incomplete_generation",
InferError::TemplateError(_) => "template_error", InferError::TemplateError(_) => "template_error",
InferError::ToolError(_) => "tool_error", InferError::ToolError(_) => "tool_error",
InferError::OutOfPages => "out_of_pages",
} }
} }
} }

View File

@ -8,6 +8,12 @@ pub(crate) struct BlockAllocation {
block_allocator: BlockAllocator, block_allocator: BlockAllocator,
} }
impl BlockAllocation {
/// Number of slots held by this allocation.
///
/// NOTE(review): the router compares this value against a request's current length in
/// tokens, so this presumably assumes one slot per token — confirm against the allocator.
pub(crate) fn len(&self) -> usize {
self.slots.len()
}
}
impl Drop for BlockAllocation { impl Drop for BlockAllocation {
fn drop(&mut self) { fn drop(&mut self) {
self.block_allocator.free(self.blocks.clone()) self.block_allocator.free(self.blocks.clone())
@ -83,6 +89,8 @@ async fn block_allocator_task(
tokens, tokens,
response_sender, response_sender,
} => { } => {
// let tokens = 16;
// Apply window size // Apply window size
let (required_blocks, repeats) = { let (required_blocks, repeats) = {
let (tokens, repeats) = match window_size { let (tokens, repeats) = match window_size {

View File

@ -34,7 +34,7 @@ pub(crate) struct Entry {
/// Block Allocation /// Block Allocation
pub block_allocation: Option<BlockAllocation>, pub block_allocation: Option<BlockAllocation>,
/// Current length (in tokens) of the request (prompt tokens + generated_tokens) /// Current length (in tokens) of the request (prompt tokens + generated_tokens)
pub current_length: u32 pub current_length: u32,
} }
/// Request Queue /// Request Queue

View File

@ -10,7 +10,7 @@ use std::sync::{
atomic::{AtomicBool, Ordering}, atomic::{AtomicBool, Ordering},
Arc, Arc,
}; };
use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient}; use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient, UpdatedRequest};
use text_generation_client::ClientError; use text_generation_client::ClientError;
use tokio::sync::mpsc::error::SendError; use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit}; use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit};
@ -288,7 +288,7 @@ async fn decode(
// Send generated tokens and filter stopped entries // Send generated tokens and filter stopped entries
filter_send_generations(generations, entries); filter_send_generations(generations, entries);
filter_update_allocations(client, entries).await; filter_update_allocations(entries).await;
// Filter next batch and remove requests that were stopped // Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await; let next_batch = filter_batch(client, next_batch, entries).await;
@ -323,7 +323,7 @@ async fn filter_batch(
next_batch: Option<CachedBatch>, next_batch: Option<CachedBatch>,
entries: &IntMap<u64, Entry>, entries: &IntMap<u64, Entry>,
) -> Option<CachedBatch> { ) -> Option<CachedBatch> {
let mut batch = next_batch?; let batch = next_batch?;
// No need to filter // No need to filter
if batch.size as usize == entries.len() { if batch.size as usize == entries.len() {
@ -331,11 +331,7 @@ async fn filter_batch(
} }
let id = batch.id; let id = batch.id;
if entries.is_empty() {
// Retain only requests that are still in entries
batch.request_ids.retain(|id| entries.contains_key(id));
if batch.request_ids.is_empty() {
// All requests have been filtered out // All requests have been filtered out
// Next batch is now empty // Next batch is now empty
// Clear it from the Python shards cache // Clear it from the Python shards cache
@ -344,8 +340,24 @@ async fn filter_batch(
None None
} else { } else {
// Filter Python shard cache // Filter Python shard cache
let updated_requests = entries
.iter()
.map(|(request_id, entry)| {
let (blocks, slots) = entry
.block_allocation
.as_ref()
.map(|alloc| (alloc.blocks.clone(), alloc.slots.clone()))
.unwrap_or((Vec::new(), Vec::new()));
UpdatedRequest {
id: *request_id,
blocks,
slots,
}
})
.collect();
// We unwrap here as we need to panic since we cannot recover if this method fails // We unwrap here as we need to panic since we cannot recover if this method fails
client.filter_batch(id, batch.request_ids).await.unwrap() client.filter_batch(id, updated_requests).await.unwrap()
} }
} }
@ -379,32 +391,36 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
} }
/// Check if block allocations need to be extended /// Check if block allocations need to be extended
/// If we don't have enough blocks, request will be filtered with an OutOfPages finish reason /// If we don't have enough blocks, request will be filtered with an OutOfPages error
#[instrument(skip_all)] #[instrument(skip_all)]
async fn filter_update_allocations(client: &mut ShardedClient, entries: &mut IntMap<u64, Entry>) { async fn filter_update_allocations(entries: &mut IntMap<u64, Entry>) {
// let mut extend_entries = Vec::with_capacity(entries.len()); entries.retain(|request_id, entry| {
// let mut finish_entries = Vec::with_capacity(entries.len()); if entry.block_allocation.is_none() {
return true;
}
// for (request_id, entry) in entries.into_iter() { // We can unwrap since we already validated above that block_allocation is not None
// tracing::info!("Allocation {}; Current Length: {}", entry.block_allocation.as_ref().unwrap().allocated_tokens, entry.current_length); let mut block_allocation = entry.block_allocation.as_ref().unwrap();
//
// if let Some(block_allocation) = &mut entry.block_allocation { // Nothing to update
// tracing::info!("Allocation {:?}", block_allocation); if entry.current_length <= block_allocation.len() as u32 {
// return true;
// if entry.current_length > block_allocation.allocated_tokens { }
// // We need to add new blocks to this entry
// let remaining_tokens = block_allocation.total_tokens - entry.current_length; // Create and enter a span to link this function back to the entry
// match block_allocation.extend(remaining_tokens).await { let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
// true => { let err = InferError::OutOfPages;
// metrics::increment_counter!("tgi_request_failure", "err" => "out_of_pages");
// }, tracing::error!("{err}");
// false => {
// // unwrap_or is valid here as we don't care if the receiver is gone.
// } entry
// } .response_tx
// } .send(Err(err))
// } .unwrap_or(());
// }
false
});
} }
/// Send responses through the `entry` response channel /// Send responses through the `entry` response channel

View File

@ -1085,8 +1085,6 @@ pub(crate) enum FinishReason {
EndOfSequenceToken, EndOfSequenceToken,
#[schema(rename = "stop_sequence")] #[schema(rename = "stop_sequence")]
StopSequence, StopSequence,
#[schema(rename = "out_of_pages")]
OutOfPages
} }
impl std::fmt::Display for FinishReason { impl std::fmt::Display for FinishReason {
@ -1095,7 +1093,6 @@ impl std::fmt::Display for FinishReason {
FinishReason::Length => write!(f, "length"), FinishReason::Length => write!(f, "length"),
FinishReason::EndOfSequenceToken => write!(f, "eos_token"), FinishReason::EndOfSequenceToken => write!(f, "eos_token"),
FinishReason::StopSequence => write!(f, "stop_sequence"), FinishReason::StopSequence => write!(f, "stop_sequence"),
FinishReason::OutOfPages => write!(f, "out_of_pages"),
} }
} }
} }

View File

@ -1859,6 +1859,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR, InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY, InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY, InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::OutOfPages => StatusCode::TOO_MANY_REQUESTS,
}; };
( (

View File

@ -158,7 +158,11 @@ class CausalLMBatch(Batch):
) )
@tracer.start_as_current_span("filter") @tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]: def filter(
self, updated_requests: List[generate_pb2.UpdatedRequest]
) -> Optional["CausalLMBatch"]:
request_ids = [r.id for r in updated_requests]
if len(request_ids) == 0: if len(request_ids) == 0:
raise ValueError("Batch must have at least one request") raise ValueError("Batch must have at least one request")
if len(request_ids) == len(self): if len(request_ids) == len(self):
@ -746,7 +750,7 @@ class CausalLM(Model):
), ),
generated_text, generated_text,
top_tokens, top_tokens,
new_input_length new_input_length,
) )
generations.append(generation) generations.append(generation)

View File

@ -82,14 +82,10 @@ class FlashCausalLMBatch(Batch):
# tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
slot_indices: torch.Tensor slot_indices: torch.Tensor
# list of length b of list of length s_i // block_size
block_tables: List[List[int]]
# tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
block_tables_tensor: torch.Tensor block_tables_tensor: torch.Tensor
# list of length b of list of length s_i
slots: List[List[int]]
# tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
slots_tensor: torch.Tensor slots: torch.Tensor
max_seqlen: int max_seqlen: int
@ -183,7 +179,6 @@ class FlashCausalLMBatch(Batch):
max_blocks = 0 max_blocks = 0
block_tables = [] block_tables = []
slots = []
flat_slots = [] flat_slots = []
# Parse batch # Parse batch
@ -253,7 +248,6 @@ class FlashCausalLMBatch(Batch):
len(flat_slots) + input_length, len(flat_slots) + input_length,
dtype=torch.int64, dtype=torch.int64,
) )
slots.append(request_slots)
flat_slots.extend(request_slots) flat_slots.extend(request_slots)
slot_indices.append(request_slot_indices) slot_indices.append(request_slot_indices)
@ -353,7 +347,7 @@ class FlashCausalLMBatch(Batch):
top_n_tokens, device=device, dtype=torch.int64 top_n_tokens, device=device, dtype=torch.int64
) )
slots_tensor = torch.tensor(flat_slots, dtype=torch.int64, device=device) slots = torch.tensor(flat_slots, dtype=torch.int64, device=device)
block_tables_tensor = torch.zeros( block_tables_tensor = torch.zeros(
(len(block_tables), max_blocks), dtype=torch.int32, device="cpu" (len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
) )
@ -370,10 +364,8 @@ class FlashCausalLMBatch(Batch):
cu_seqlen_prefill=cu_seqlen_prefill, cu_seqlen_prefill=cu_seqlen_prefill,
prefill_cache_indices=prefill_cache_indices, prefill_cache_indices=prefill_cache_indices,
slot_indices=slot_indices, slot_indices=slot_indices,
block_tables=block_tables,
block_tables_tensor=block_tables_tensor, block_tables_tensor=block_tables_tensor,
slots=slots, slots=slots,
slots_tensor=slots_tensor,
max_seqlen=max_seqlen, max_seqlen=max_seqlen,
prefill_head_indices=prefill_head_indices, prefill_head_indices=prefill_head_indices,
prefill_next_token_indices=prefill_next_token_indices, prefill_next_token_indices=prefill_next_token_indices,
@ -405,11 +397,13 @@ class FlashCausalLMBatch(Batch):
return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device) return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
@tracer.start_as_current_span("filter") @tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": def filter(
if len(request_ids) == 0: self, updated_requests: List[generate_pb2.UpdatedRequest]
) -> Optional["FlashCausalLMBatch"]:
if len(updated_requests) == 0:
raise ValueError("Batch must have at least one request") raise ValueError("Batch must have at least one request")
# We assume that if len(requests) == len(self) then the requests are the same # We assume that if len(requests) == len(self) then the requests are the same
if len(request_ids) == len(self): if len(updated_requests) == len(self):
return self return self
device = self.input_ids.device device = self.input_ids.device
@ -425,7 +419,6 @@ class FlashCausalLMBatch(Batch):
requests = [] requests = []
block_tables = [] block_tables = []
slots = []
flat_slots = [] flat_slots = []
all_input_ids = [] all_input_ids = []
@ -439,7 +432,9 @@ class FlashCausalLMBatch(Batch):
num_blocks = 0 num_blocks = 0
max_blocks = 0 max_blocks = 0
for i, request_id in enumerate(request_ids): for i, request in enumerate(updated_requests):
request_id = request.id
idx = self.requests_idx_mapping[request_id] idx = self.requests_idx_mapping[request_id]
indices.append(idx) indices.append(idx)
requests_idx_mapping[request_id] = i requests_idx_mapping[request_id] = i
@ -461,13 +456,12 @@ class FlashCausalLMBatch(Batch):
top_n_tokens.append(self.top_n_tokens[idx]) top_n_tokens.append(self.top_n_tokens[idx])
request_block_table = self.block_tables[idx] request_block_table = request.blocks
num_blocks += len(request_block_table) num_blocks += len(request_block_table)
block_tables.append(request_block_table) block_tables.append(request_block_table)
# List of slots allocated for this request # List of slots allocated for this request
request_slots = self.slots[idx] request_slots = request.slots
slots.append(request_slots)
# Index # Index
slot_indices.append(len(flat_slots) + request_input_length - 1) slot_indices.append(len(flat_slots) + request_input_length - 1)
@ -479,7 +473,6 @@ class FlashCausalLMBatch(Batch):
input_ids = self.input_ids[indices] input_ids = self.input_ids[indices]
position_ids = self.position_ids[indices] position_ids = self.position_ids[indices]
all_input_ids_tensor = self.all_input_ids_tensor[indices] all_input_ids_tensor = self.all_input_ids_tensor[indices]
block_tables_tensor = self.block_tables_tensor[indices]
input_lengths_tensor = self.input_lengths_tensor[indices] input_lengths_tensor = self.input_lengths_tensor[indices]
next_token_chooser = self.next_token_chooser.filter(indices) next_token_chooser = self.next_token_chooser.filter(indices)
top_n_tokens_tensor = self.top_n_tokens_tensor[indices] top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
@ -487,10 +480,20 @@ class FlashCausalLMBatch(Batch):
self.speculative_ids[indices] if self.speculative_ids is not None else None self.speculative_ids[indices] if self.speculative_ids is not None else None
) )
# Create block_tables_tensor on CPU
block_tables_tensor = torch.zeros(
(len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
)
for i, request_blocks in enumerate(block_tables):
block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
# Allocate on GPU # Allocate on GPU
slots_tensor = torch.tensor(flat_slots, dtype=torch.int64, device=device) slots = torch.tensor(flat_slots, dtype=torch.int64, device=device)
slot_indices = torch.tensor(slot_indices, dtype=torch.int64, device=device) slot_indices = torch.tensor(slot_indices, dtype=torch.int64, device=device)
# Move to GPU
block_tables_tensor = block_tables_tensor.to(device)
return type(self)( return type(self)(
batch_id=self.batch_id, batch_id=self.batch_id,
requests=requests, requests=requests,
@ -500,10 +503,8 @@ class FlashCausalLMBatch(Batch):
cu_seqlen_prefill=None, cu_seqlen_prefill=None,
prefill_cache_indices=None, prefill_cache_indices=None,
slot_indices=slot_indices, slot_indices=slot_indices,
block_tables=block_tables,
block_tables_tensor=block_tables_tensor, block_tables_tensor=block_tables_tensor,
slots=slots, slots=slots,
slots_tensor=slots_tensor,
max_seqlen=max_seqlen, max_seqlen=max_seqlen,
prefill_head_indices=None, prefill_head_indices=None,
prefill_next_token_indices=None, prefill_next_token_indices=None,
@ -538,7 +539,7 @@ class FlashCausalLMBatch(Batch):
max_seqlen = 0 max_seqlen = 0
for b in batches: for b in batches:
total_batch_size += len(b) total_batch_size += len(b)
total_slots += len(b.slots_tensor) total_slots += len(b.slots)
num_blocks += b.num_blocks num_blocks += b.num_blocks
speculative_length = ( speculative_length = (
b.speculative_ids.shape[1] if b.speculative_ids is not None else 0 b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
@ -561,7 +562,7 @@ class FlashCausalLMBatch(Batch):
input_ids = batches[0].input_ids.new_empty(total_batch_size) input_ids = batches[0].input_ids.new_empty(total_batch_size)
position_ids = batches[0].position_ids.new_empty(total_batch_size) position_ids = batches[0].position_ids.new_empty(total_batch_size)
slots_tensor = batches[0].slots_tensor.new_empty(total_slots) slots = batches[0].slots.new_empty(total_slots)
slot_indices = batches[0].slot_indices.new_empty(total_batch_size) slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
input_lengths_tensor = batches[0].input_lengths_tensor.new_empty( input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
total_batch_size total_batch_size
@ -576,8 +577,6 @@ class FlashCausalLMBatch(Batch):
total_batch_size, total_batch_size,
) )
slots = []
block_tables = []
all_input_ids = [] all_input_ids = []
input_lengths = [] input_lengths = []
@ -606,7 +605,7 @@ class FlashCausalLMBatch(Batch):
start_index = cumulative_batch_size start_index = cumulative_batch_size
end_index = cumulative_batch_size + len(batch) end_index = cumulative_batch_size + len(batch)
slots_start_index = cumulative_slots slots_start_index = cumulative_slots
slots_end_index = cumulative_slots + len(batch.slots_tensor) slots_end_index = cumulative_slots + len(batch.slots)
# Copy tensors (GPU) # Copy tensors (GPU)
input_ids[start_index:end_index] = batch.input_ids input_ids[start_index:end_index] = batch.input_ids
@ -614,7 +613,7 @@ class FlashCausalLMBatch(Batch):
slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
slots_tensor[slots_start_index:slots_end_index] = batch.slots_tensor slots[slots_start_index:slots_end_index] = batch.slots
all_input_ids_tensor[ all_input_ids_tensor[
start_index:end_index, : batch.all_input_ids_tensor.shape[1] start_index:end_index, : batch.all_input_ids_tensor.shape[1]
@ -624,8 +623,6 @@ class FlashCausalLMBatch(Batch):
start_index:end_index, : batch.block_tables_tensor.shape[1] start_index:end_index, : batch.block_tables_tensor.shape[1]
] = batch.block_tables_tensor[:, :max_blocks] ] = batch.block_tables_tensor[:, :max_blocks]
slots.extend(batch.slots)
block_tables.extend(batch.block_tables)
all_input_ids.extend(batch.all_input_ids) all_input_ids.extend(batch.all_input_ids)
input_lengths.extend(batch.input_lengths) input_lengths.extend(batch.input_lengths)
@ -640,7 +637,7 @@ class FlashCausalLMBatch(Batch):
# Update # Update
cumulative_batch_size += len(batch) cumulative_batch_size += len(batch)
cumulative_slots += len(batch.slots_tensor) cumulative_slots += len(batch.slots)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, next_token_chooser_parameters,
@ -665,10 +662,8 @@ class FlashCausalLMBatch(Batch):
cu_seqlen_prefill=None, cu_seqlen_prefill=None,
prefill_cache_indices=None, prefill_cache_indices=None,
slot_indices=slot_indices, slot_indices=slot_indices,
block_tables=block_tables,
block_tables_tensor=block_tables_tensor, block_tables_tensor=block_tables_tensor,
slots=slots, slots=slots,
slots_tensor=slots_tensor,
max_seqlen=max_seqlen, max_seqlen=max_seqlen,
prefill_head_indices=None, prefill_head_indices=None,
prefill_next_token_indices=None, prefill_next_token_indices=None,
@ -969,7 +964,7 @@ class FlashCausalLM(Model):
cu_seqlen_prefill = batch.cu_seqlen_prefill cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = self.kv_cache kv_cache = self.kv_cache
block_tables = batch.block_tables_tensor block_tables = batch.block_tables_tensor
slots = batch.slots_tensor[batch.slot_indices] slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices lm_head_indices = batch.prefill_head_indices
@ -1008,7 +1003,7 @@ class FlashCausalLM(Model):
cu_seqlen_prefill = batch.cu_seqlen_prefill cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = self.kv_cache kv_cache = self.kv_cache
block_tables = batch.block_tables_tensor block_tables = batch.block_tables_tensor
slots = batch.slots_tensor[batch.slot_indices] slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices lm_head_indices = batch.prefill_head_indices
@ -1350,7 +1345,7 @@ class FlashCausalLM(Model):
), ),
generated_text, generated_text,
top_tokens, top_tokens,
input_length + n_accepted_ids input_length + n_accepted_ids,
) )
generations.append(generation) generations.append(generation)

View File

@ -214,7 +214,11 @@ class IdeficsCausalLMBatch(Batch):
) )
@tracer.start_as_current_span("filter") @tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]) -> Optional["IdeficsCausalLMBatch"]: def filter(
self, updated_requests: List[generate_pb2.UpdatedRequest]
) -> Optional["IdeficsCausalLMBatch"]:
request_ids = [r.id for r in updated_requests]
# It deletes requests from the batch. For instance when client lost connection # It deletes requests from the batch. For instance when client lost connection
if len(request_ids) == 0: if len(request_ids) == 0:
raise ValueError("Batch must have at least one request") raise ValueError("Batch must have at least one request")
@ -829,7 +833,7 @@ class IdeficsCausalLM(Model):
), ),
generated_text, generated_text,
top_tokens, top_tokens,
new_input_length new_input_length,
) )
generations.append(generation) generations.append(generation)

View File

@ -195,7 +195,11 @@ class MambaBatch(Batch):
max_tokens=max_tokens, max_tokens=max_tokens,
) )
def filter(self, request_ids: List[int]) -> Optional["MambaBatch"]: def filter(
self, updated_requests: List[generate_pb2.UpdatedRequest]
) -> Optional["MambaBatch"]:
request_ids = [r.id for r in updated_requests]
if len(request_ids) == 0: if len(request_ids) == 0:
raise ValueError("Batch must have at least one request") raise ValueError("Batch must have at least one request")
if len(request_ids) == len(self): if len(request_ids) == len(self):
@ -775,7 +779,7 @@ class Mamba(Model):
), ),
generated_text, generated_text,
top_tokens, top_tokens,
new_input_length new_input_length,
) )
generations.append(generation) generations.append(generation)

View File

@ -166,7 +166,11 @@ class Seq2SeqLMBatch(Batch):
) )
@tracer.start_as_current_span("filter") @tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]) -> Optional["Seq2SeqLMBatch"]: def filter(
self, updated_requests: List[generate_pb2.UpdatedRequest]
) -> Optional["Seq2SeqLMBatch"]:
request_ids = [r.id for r in updated_requests]
if len(request_ids) == 0: if len(request_ids) == 0:
raise ValueError("Batch must have at least one request") raise ValueError("Batch must have at least one request")
if len(request_ids) == len(self): if len(request_ids) == len(self):

View File

@ -28,7 +28,7 @@ class Batch(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def filter(self, request_ids: List[int]) -> "Batch": def filter(self, updated_requests: List[generate_pb2.UpdatedRequest]) -> "Batch":
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod

View File

@ -122,8 +122,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
return batch return batch
@tracer.start_as_current_span("filter") @tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]): def filter(
batch = super().filter(request_ids) self, updated_requests: List[generate_pb2.UpdatedRequest]
) -> Optional["VlmCausalLMBatch"]:
batch = super().filter(updated_requests)
batch.pixel_values = None batch.pixel_values = None
batch.pixel_attention_mask = None batch.pixel_attention_mask = None
batch.image_sizes = None batch.image_sizes = None

View File

@ -83,7 +83,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
batch = self.cache.pop(request.batch_id) batch = self.cache.pop(request.batch_id)
if batch is None: if batch is None:
raise ValueError(f"Batch ID {request.batch_id} not found in cache.") raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
filtered_batch = batch.filter(request.request_ids) filtered_batch = batch.filter(request.updated_requests)
self.cache.set(filtered_batch) self.cache.set(filtered_batch)
return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())