Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-25 12:02:08 +00:00)

Commit: 1cc86930a6
Parent: 18e77a5cc7
Commit message: wip
@@ -17,8 +17,6 @@ service TextGenerationService {
     rpc Prefill (PrefillRequest) returns (PrefillResponse);
     /// Decode token for a list of prefilled batches
     rpc Decode (DecodeRequest) returns (DecodeResponse);
-    /// Update batch
-    rpc Update(UpdateRequest) returns (UpdateResponse);
     /// Health check
     rpc Health (HealthRequest) returns (HealthResponse);
 }

@@ -204,11 +202,20 @@ message Generation {
     uint32 current_length = 6;
 }

+message UpdatedRequest {
+    /// Request ID
+    uint64 id = 1;
+    /// Paged attention blocks
+    repeated uint32 blocks = 2;
+    /// Paged attention slots
+    repeated uint32 slots = 3;
+}
+
 message FilterBatchRequest {
     /// Batch ID
     uint64 batch_id = 1;
     /// Requests to keep
-    repeated uint64 request_ids = 2;
+    repeated UpdatedRequest updated_requests = 2;
 }

 message FilterBatchResponse {

@@ -255,26 +262,6 @@ message DecodeResponse {
     optional uint64 concat_ns = 6;
 }

-message ExtendedRequest {
-    /// Request ID
-    uint64 request_id = 1;
-    /// Paged attention blocks to add
-    repeated uint32 blocks = 2;
-    /// Paged attention slots to add
-    repeated uint32 slots = 3;
-}
-
-message UpdateRequest {
-    /// Batch ID
-    uint64 batch_id = 1;
-    /// Requests to update
-    repeated ExtendedRequest extend_requests = 2;
-    /// Requests to terminate
-    repeated uint64 terminated_request_ids = 3;
-}
-
-message UpdateResponse {}
-
 message WarmupRequest {
     /// Batch to warmup on
     Batch batch = 1;
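For illustration only: a minimal Rust sketch of the filter payload introduced by the proto change above. The structs below are hand-written stand-ins for the prost/tonic-generated `UpdatedRequest` and `FilterBatchRequest` types (the real ones live in `pb::generate::v3`), so the derives and exact field types are assumptions that simply mirror the message definitions.

// Hand-written stand-ins for the prost-generated proto types above.
#[derive(Debug, Clone)]
struct UpdatedRequest {
    /// Request ID
    id: u64,
    /// Paged attention blocks kept for this request
    blocks: Vec<u32>,
    /// Paged attention slots kept for this request
    slots: Vec<u32>,
}

#[derive(Debug, Clone)]
struct FilterBatchRequest {
    batch_id: u64,
    updated_requests: Vec<UpdatedRequest>,
}

fn main() {
    // A filter request now carries the surviving requests together with
    // their current block/slot allocations instead of bare request ids.
    let request = FilterBatchRequest {
        batch_id: 0,
        updated_requests: vec![UpdatedRequest {
            id: 42,
            blocks: vec![0, 1, 2],
            slots: vec![0, 1, 2, 3],
        }],
    };
    println!("{request:?}");
}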
@@ -90,11 +90,11 @@ impl Client {
     pub async fn filter_batch(
         &mut self,
         batch_id: u64,
-        request_ids: Vec<u64>,
+        updated_requests: Vec<UpdatedRequest>,
     ) -> Result<Option<CachedBatch>> {
         let request = tonic::Request::new(FilterBatchRequest {
             batch_id,
-            request_ids,
+            updated_requests,
         })
         .inject_context();
         let filtered_batch = self.stub.filter_batch(request).await?.into_inner();

@@ -8,6 +8,6 @@ pub use client::Client;
 pub use pb::generate::v3::{
     input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
     HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
-    StoppingCriteriaParameters, Tokens,
+    StoppingCriteriaParameters, Tokens, UpdatedRequest,
 };
 pub use sharded_client::ShardedClient;

@@ -10,7 +10,7 @@ use tracing::instrument;
 use v3::client::{DecodeTimings, PrefillTimings};
 use v3::{
     Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
-    NextTokenChooserParameters, Request, StoppingCriteriaParameters,
+    NextTokenChooserParameters, Request, StoppingCriteriaParameters, UpdatedRequest,
 };

 #[derive(Debug, Clone)]

@@ -84,12 +84,12 @@ impl ShardedClient {
     pub async fn filter_batch(
         &mut self,
         batch_id: u64,
-        request_ids: Vec<u64>,
+        updated_requests: Vec<UpdatedRequest>,
     ) -> Result<Option<CachedBatch>> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
-            .map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))
+            .map(|client| Box::pin(client.filter_batch(batch_id, updated_requests.clone())))
             .collect();
         // all shards return the same message
         join_all(futures).await.pop().unwrap()
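A hedged sketch of the fan-out this change implies: every shard receives a clone of the same surviving-request metadata. `MockClient` and its return value are invented for the example; the real client method is async and returns `Result<Option<CachedBatch>>`.

// Stand-ins for the generated proto type and for a single shard client.
#[derive(Debug, Clone)]
struct UpdatedRequest {
    id: u64,
    blocks: Vec<u32>,
    slots: Vec<u32>,
}

struct MockClient;

impl MockClient {
    fn filter_batch(&mut self, batch_id: u64, updated_requests: Vec<UpdatedRequest>) -> usize {
        for r in &updated_requests {
            println!(
                "shard filtering batch {batch_id}: keep request {} ({} blocks, {} slots)",
                r.id,
                r.blocks.len(),
                r.slots.len()
            );
        }
        updated_requests.len()
    }
}

fn main() {
    let mut clients = vec![MockClient, MockClient];
    let updated_requests = vec![UpdatedRequest { id: 7, blocks: vec![3, 4], slots: vec![12, 13] }];

    // Mirrors ShardedClient::filter_batch: every shard gets a clone of the
    // same surviving-request metadata; all shards return the same result.
    let results: Vec<_> = clients
        .iter_mut()
        .map(|client| client.filter_batch(0, updated_requests.clone()))
        .collect();
    assert_eq!(results, vec![1, 1]);
}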
@@ -506,6 +506,8 @@ pub enum InferError {
     TemplateError(#[from] minijinja::Error),
     #[error("Tool error: {0}")]
     ToolError(String),
+    #[error("Request could not be re-allocated: out of pages")]
+    OutOfPages,
 }

 impl InferError {

@@ -517,6 +519,7 @@ impl InferError {
             InferError::IncompleteGeneration => "incomplete_generation",
             InferError::TemplateError(_) => "template_error",
             InferError::ToolError(_) => "tool_error",
+            InferError::OutOfPages => "out_of_pages",
         }
     }
 }

@@ -8,6 +8,12 @@ pub(crate) struct BlockAllocation {
     block_allocator: BlockAllocator,
 }

+impl BlockAllocation {
+    pub(crate) fn len(&self) -> usize {
+        self.slots.len()
+    }
+}
+
 impl Drop for BlockAllocation {
     fn drop(&mut self) {
         self.block_allocator.free(self.blocks.clone())
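A small, self-contained sketch of how the new `len()` helper is meant to be used by the scheduler (simplified stand-in types; the real check appears in `filter_update_allocations` further down):

// Simplified stand-in for the router's BlockAllocation.
struct BlockAllocation {
    blocks: Vec<u32>,
    slots: Vec<u32>,
}

impl BlockAllocation {
    // Number of slots (i.e. tokens) this allocation can hold.
    fn len(&self) -> usize {
        self.slots.len()
    }
}

fn main() {
    let allocation = BlockAllocation { blocks: vec![0], slots: vec![0, 1, 2, 3] };
    let current_length: u32 = 5;

    println!("{} blocks, {} slots allocated", allocation.blocks.len(), allocation.len());

    // The scheduler compares the request length against the allocation size;
    // a request that outgrew its allocation is treated as out of pages.
    if current_length > allocation.len() as u32 {
        println!("out of pages");
    }
}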
@@ -83,6 +89,8 @@ async fn block_allocator_task(
             tokens,
             response_sender,
         } => {
+            // let tokens = 16;
+
             // Apply window size
             let (required_blocks, repeats) = {
                 let (tokens, repeats) = match window_size {

@@ -34,7 +34,7 @@ pub(crate) struct Entry {
     /// Block Allocation
     pub block_allocation: Option<BlockAllocation>,
     /// Current length (in tokens) of the request (prompt tokens + generated_tokens)
-    pub current_length: u32
+    pub current_length: u32,
 }

 /// Request Queue

@@ -10,7 +10,7 @@ use std::sync::{
     atomic::{AtomicBool, Ordering},
     Arc,
 };
-use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient};
+use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient, UpdatedRequest};
 use text_generation_client::ClientError;
 use tokio::sync::mpsc::error::SendError;
 use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit};

@@ -288,7 +288,7 @@ async fn decode(
     // Send generated tokens and filter stopped entries
     filter_send_generations(generations, entries);

-    filter_update_allocations(client, entries).await;
+    filter_update_allocations(entries).await;

     // Filter next batch and remove requests that were stopped
     let next_batch = filter_batch(client, next_batch, entries).await;

@@ -323,7 +323,7 @@ async fn filter_batch(
     next_batch: Option<CachedBatch>,
     entries: &IntMap<u64, Entry>,
 ) -> Option<CachedBatch> {
-    let mut batch = next_batch?;
+    let batch = next_batch?;

     // No need to filter
     if batch.size as usize == entries.len() {

@@ -331,11 +331,7 @@ async fn filter_batch(
     }

     let id = batch.id;
-
-    // Retain only requests that are still in entries
-    batch.request_ids.retain(|id| entries.contains_key(id));
-
-    if batch.request_ids.is_empty() {
+    if entries.is_empty() {
         // All requests have been filtered out
         // Next batch is now empty
         // Clear it from the Python shards cache

@@ -344,8 +340,24 @@ async fn filter_batch(
         None
     } else {
         // Filter Python shard cache
+        let updated_requests = entries
+            .iter()
+            .map(|(request_id, entry)| {
+                let (blocks, slots) = entry
+                    .block_allocation
+                    .as_ref()
+                    .map(|alloc| (alloc.blocks.clone(), alloc.slots.clone()))
+                    .unwrap_or((Vec::new(), Vec::new()));
+                UpdatedRequest {
+                    id: *request_id,
+                    blocks,
+                    slots,
+                }
+            })
+            .collect();
+
         // We unwrap here as we need to panic since we cannot recover if this method fails
-        client.filter_batch(id, batch.request_ids).await.unwrap()
+        client.filter_batch(id, updated_requests).await.unwrap()
     }
 }

@@ -379,32 +391,36 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
 }

 /// Check if block allocations need to be extended
-/// If we don't have enough blocks, request will be filtered with an OutOfPages finish reason
+/// If we don't have enough blocks, request will be filtered with an OutOfPages error
 #[instrument(skip_all)]
-async fn filter_update_allocations(client: &mut ShardedClient, entries: &mut IntMap<u64, Entry>) {
-    // let mut extend_entries = Vec::with_capacity(entries.len());
-    // let mut finish_entries = Vec::with_capacity(entries.len());
-
-    // for (request_id, entry) in entries.into_iter() {
-    //     tracing::info!("Allocation {}; Current Length: {}", entry.block_allocation.as_ref().unwrap().allocated_tokens, entry.current_length);
-    //
-    //     if let Some(block_allocation) = &mut entry.block_allocation {
-    //         tracing::info!("Allocation {:?}", block_allocation);
-    //
-    //         if entry.current_length > block_allocation.allocated_tokens {
-    //             // We need to add new blocks to this entry
-    //             let remaining_tokens = block_allocation.total_tokens - entry.current_length;
-    //             match block_allocation.extend(remaining_tokens).await {
-    //                 true => {
-    //
-    //                 },
-    //                 false => {
-    //
-    //                 }
-    //             }
-    //         }
-    //     }
-    // }
+async fn filter_update_allocations(entries: &mut IntMap<u64, Entry>) {
+    entries.retain(|request_id, entry| {
+        if entry.block_allocation.is_none() {
+            return true;
+        }
+
+        // We can unwrap since we already validated above that block_allocation is not None
+        let mut block_allocation = entry.block_allocation.as_ref().unwrap();
+
+        // Nothing to update
+        if entry.current_length <= block_allocation.len() as u32 {
+            return true;
+        }
+
+        // Create and enter a span to link this function back to the entry
+        let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
+        let err = InferError::OutOfPages;
+        metrics::increment_counter!("tgi_request_failure", "err" => "out_of_pages");
+        tracing::error!("{err}");
+
+        // unwrap_or is valid here as we don't care if the receiver is gone.
+        entry
+            .response_tx
+            .send(Err(err))
+            .unwrap_or(());
+
+        false
+    });
 }

 /// Send responses through the `entry` response channel
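The new `filter_update_allocations` boils down to a retain-and-notify pattern: keep the entries that still fit their allocation, send an error to the ones that do not, then drop them from the map. A minimal, runnable sketch of that pattern with simplified stand-in types (std mpsc instead of the router's response channel, a plain `HashMap` instead of `IntMap`, and a string error in place of `InferError::OutOfPages`):

use std::collections::HashMap;
use std::sync::mpsc;

// Simplified stand-in for the router's Entry.
struct Entry {
    current_length: u32,
    allocated_slots: usize,
    response_tx: mpsc::Sender<Result<(), String>>,
}

fn filter_out_of_pages(entries: &mut HashMap<u64, Entry>) {
    entries.retain(|_request_id, entry| {
        // Keep entries that still fit in their allocation.
        if entry.current_length as usize <= entry.allocated_slots {
            return true;
        }
        // Notify the waiting request, then drop it from the map.
        // unwrap_or is fine here: we don't care if the receiver is gone.
        entry
            .response_tx
            .send(Err("out of pages".to_string()))
            .unwrap_or(());
        false
    });
}

fn main() {
    let (tx, rx) = mpsc::channel();
    let mut entries = HashMap::new();
    entries.insert(1, Entry { current_length: 8, allocated_slots: 4, response_tx: tx });

    filter_out_of_pages(&mut entries);
    assert!(entries.is_empty());
    assert!(rx.recv().unwrap().is_err());
}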
@@ -1085,8 +1085,6 @@ pub(crate) enum FinishReason {
     EndOfSequenceToken,
     #[schema(rename = "stop_sequence")]
     StopSequence,
-    #[schema(rename = "out_of_pages")]
-    OutOfPages
 }

 impl std::fmt::Display for FinishReason {

@@ -1095,7 +1093,6 @@ impl std::fmt::Display for FinishReason {
             FinishReason::Length => write!(f, "length"),
             FinishReason::EndOfSequenceToken => write!(f, "eos_token"),
             FinishReason::StopSequence => write!(f, "stop_sequence"),
-            FinishReason::OutOfPages => write!(f, "out_of_pages"),
         }
     }
 }

@@ -1859,6 +1859,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
             InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
             InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
             InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY,
+            InferError::OutOfPages => StatusCode::TOO_MANY_REQUESTS,
         };

         (

@@ -158,7 +158,11 @@ class CausalLMBatch(Batch):
         )

     @tracer.start_as_current_span("filter")
-    def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:
+    def filter(
+        self, updated_requests: List[generate_pb2.UpdatedRequest]
+    ) -> Optional["CausalLMBatch"]:
+        request_ids = [r.id for r in updated_requests]
+
         if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
         if len(request_ids) == len(self):

@@ -746,7 +750,7 @@ class CausalLM(Model):
                 ),
                 generated_text,
                 top_tokens,
-                new_input_length
+                new_input_length,
             )

             generations.append(generation)

@@ -82,14 +82,10 @@ class FlashCausalLMBatch(Batch):
     # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
     slot_indices: torch.Tensor

-    # list of length b of list of length s_i // block_size
-    block_tables: List[List[int]]
     # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
     block_tables_tensor: torch.Tensor
-    # list of length b of list of length s_i
-    slots: List[List[int]]
     # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
-    slots_tensor: torch.Tensor
+    slots: torch.Tensor

     max_seqlen: int

@@ -183,7 +179,6 @@ class FlashCausalLMBatch(Batch):
         max_blocks = 0

         block_tables = []
-        slots = []
         flat_slots = []

         # Parse batch

@@ -253,7 +248,6 @@ class FlashCausalLMBatch(Batch):
                 len(flat_slots) + input_length,
                 dtype=torch.int64,
             )
-            slots.append(request_slots)
             flat_slots.extend(request_slots)
             slot_indices.append(request_slot_indices)

@@ -353,7 +347,7 @@ class FlashCausalLMBatch(Batch):
             top_n_tokens, device=device, dtype=torch.int64
         )

-        slots_tensor = torch.tensor(flat_slots, dtype=torch.int64, device=device)
+        slots = torch.tensor(flat_slots, dtype=torch.int64, device=device)
         block_tables_tensor = torch.zeros(
             (len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
         )

@@ -370,10 +364,8 @@ class FlashCausalLMBatch(Batch):
             cu_seqlen_prefill=cu_seqlen_prefill,
             prefill_cache_indices=prefill_cache_indices,
             slot_indices=slot_indices,
-            block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
             slots=slots,
-            slots_tensor=slots_tensor,
             max_seqlen=max_seqlen,
             prefill_head_indices=prefill_head_indices,
             prefill_next_token_indices=prefill_next_token_indices,

@@ -405,11 +397,13 @@ class FlashCausalLMBatch(Batch):
         return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)

     @tracer.start_as_current_span("filter")
-    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
-        if len(request_ids) == 0:
+    def filter(
+        self, updated_requests: List[generate_pb2.UpdatedRequest]
+    ) -> Optional["FlashCausalLMBatch"]:
+        if len(updated_requests) == 0:
             raise ValueError("Batch must have at least one request")
         # We assume that if len(requests) == len(self) then the requests are the same
-        if len(request_ids) == len(self):
+        if len(updated_requests) == len(self):
             return self

         device = self.input_ids.device

@@ -425,7 +419,6 @@ class FlashCausalLMBatch(Batch):

         requests = []
         block_tables = []
-        slots = []
         flat_slots = []
         all_input_ids = []

@@ -439,7 +432,9 @@ class FlashCausalLMBatch(Batch):
         num_blocks = 0
         max_blocks = 0

-        for i, request_id in enumerate(request_ids):
+        for i, request in enumerate(updated_requests):
+            request_id = request.id
+
             idx = self.requests_idx_mapping[request_id]
             indices.append(idx)
             requests_idx_mapping[request_id] = i

@@ -461,13 +456,12 @@ class FlashCausalLMBatch(Batch):

             top_n_tokens.append(self.top_n_tokens[idx])

-            request_block_table = self.block_tables[idx]
+            request_block_table = request.blocks
             num_blocks += len(request_block_table)
             block_tables.append(request_block_table)

             # List of slots allocated for this request
-            request_slots = self.slots[idx]
-            slots.append(request_slots)
+            request_slots = request.slots

             # Index
             slot_indices.append(len(flat_slots) + request_input_length - 1)

@@ -479,7 +473,6 @@ class FlashCausalLMBatch(Batch):
         input_ids = self.input_ids[indices]
         position_ids = self.position_ids[indices]
         all_input_ids_tensor = self.all_input_ids_tensor[indices]
-        block_tables_tensor = self.block_tables_tensor[indices]
         input_lengths_tensor = self.input_lengths_tensor[indices]
         next_token_chooser = self.next_token_chooser.filter(indices)
         top_n_tokens_tensor = self.top_n_tokens_tensor[indices]

@@ -487,10 +480,20 @@ class FlashCausalLMBatch(Batch):
             self.speculative_ids[indices] if self.speculative_ids is not None else None
         )

+        # Create block_tables_tensor on CPU
+        block_tables_tensor = torch.zeros(
+            (len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
+        )
+        for i, request_blocks in enumerate(block_tables):
+            block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
+
         # Allocate on GPU
-        slots_tensor = torch.tensor(flat_slots, dtype=torch.int64, device=device)
+        slots = torch.tensor(flat_slots, dtype=torch.int64, device=device)
         slot_indices = torch.tensor(slot_indices, dtype=torch.int64, device=device)

+        # Move to GPU
+        block_tables_tensor = block_tables_tensor.to(device)
+
         return type(self)(
             batch_id=self.batch_id,
             requests=requests,

@@ -500,10 +503,8 @@ class FlashCausalLMBatch(Batch):
             cu_seqlen_prefill=None,
             prefill_cache_indices=None,
             slot_indices=slot_indices,
-            block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
             slots=slots,
-            slots_tensor=slots_tensor,
             max_seqlen=max_seqlen,
             prefill_head_indices=None,
             prefill_next_token_indices=None,

@@ -538,7 +539,7 @@ class FlashCausalLMBatch(Batch):
         max_seqlen = 0
         for b in batches:
             total_batch_size += len(b)
-            total_slots += len(b.slots_tensor)
+            total_slots += len(b.slots)
             num_blocks += b.num_blocks
             speculative_length = (
                 b.speculative_ids.shape[1] if b.speculative_ids is not None else 0

@@ -561,7 +562,7 @@ class FlashCausalLMBatch(Batch):

         input_ids = batches[0].input_ids.new_empty(total_batch_size)
         position_ids = batches[0].position_ids.new_empty(total_batch_size)
-        slots_tensor = batches[0].slots_tensor.new_empty(total_slots)
+        slots = batches[0].slots.new_empty(total_slots)
         slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
         input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
             total_batch_size

@@ -576,8 +577,6 @@ class FlashCausalLMBatch(Batch):
             total_batch_size,
         )

-        slots = []
-        block_tables = []
         all_input_ids = []

         input_lengths = []

@@ -606,7 +605,7 @@ class FlashCausalLMBatch(Batch):
             start_index = cumulative_batch_size
             end_index = cumulative_batch_size + len(batch)
             slots_start_index = cumulative_slots
-            slots_end_index = cumulative_slots + len(batch.slots_tensor)
+            slots_end_index = cumulative_slots + len(batch.slots)

             # Copy tensors (GPU)
             input_ids[start_index:end_index] = batch.input_ids

@@ -614,7 +613,7 @@ class FlashCausalLMBatch(Batch):
             slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
             input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
             top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
-            slots_tensor[slots_start_index:slots_end_index] = batch.slots_tensor
+            slots[slots_start_index:slots_end_index] = batch.slots

             all_input_ids_tensor[
                 start_index:end_index, : batch.all_input_ids_tensor.shape[1]

@@ -624,8 +623,6 @@ class FlashCausalLMBatch(Batch):
                 start_index:end_index, : batch.block_tables_tensor.shape[1]
             ] = batch.block_tables_tensor[:, :max_blocks]

-            slots.extend(batch.slots)
-            block_tables.extend(batch.block_tables)
             all_input_ids.extend(batch.all_input_ids)

             input_lengths.extend(batch.input_lengths)

@@ -640,7 +637,7 @@ class FlashCausalLMBatch(Batch):

             # Update
             cumulative_batch_size += len(batch)
-            cumulative_slots += len(batch.slots_tensor)
+            cumulative_slots += len(batch.slots)

         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             next_token_chooser_parameters,

@@ -665,10 +662,8 @@ class FlashCausalLMBatch(Batch):
            cu_seqlen_prefill=None,
            prefill_cache_indices=None,
            slot_indices=slot_indices,
-           block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
-           slots_tensor=slots_tensor,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,

@@ -969,7 +964,7 @@ class FlashCausalLM(Model):
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = self.kv_cache
            block_tables = batch.block_tables_tensor
-           slots = batch.slots_tensor[batch.slot_indices]
+           slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_seqlen
            lm_head_indices = batch.prefill_head_indices

@@ -1008,7 +1003,7 @@ class FlashCausalLM(Model):
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = self.kv_cache
            block_tables = batch.block_tables_tensor
-           slots = batch.slots_tensor[batch.slot_indices]
+           slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_seqlen
            lm_head_indices = batch.prefill_head_indices

@@ -1350,7 +1345,7 @@ class FlashCausalLM(Model):
                 ),
                 generated_text,
                 top_tokens,
-                input_length + n_accepted_ids
+                input_length + n_accepted_ids,
             )

             generations.append(generation)

@@ -214,7 +214,11 @@ class IdeficsCausalLMBatch(Batch):
         )

     @tracer.start_as_current_span("filter")
-    def filter(self, request_ids: List[int]) -> Optional["IdeficsCausalLMBatch"]:
+    def filter(
+        self, updated_requests: List[generate_pb2.UpdatedRequest]
+    ) -> Optional["IdeficsCausalLMBatch"]:
+        request_ids = [r.id for r in updated_requests]
+
         # It deletes requests from the batch. For instance when client lost connection
         if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")

@@ -829,7 +833,7 @@ class IdeficsCausalLM(Model):
                 ),
                 generated_text,
                 top_tokens,
-                new_input_length
+                new_input_length,
             )

             generations.append(generation)

@@ -195,7 +195,11 @@ class MambaBatch(Batch):
             max_tokens=max_tokens,
         )

-    def filter(self, request_ids: List[int]) -> Optional["MambaBatch"]:
+    def filter(
+        self, updated_requests: List[generate_pb2.UpdatedRequest]
+    ) -> Optional["MambaBatch"]:
+        request_ids = [r.id for r in updated_requests]
+
         if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
         if len(request_ids) == len(self):

@@ -775,7 +779,7 @@ class Mamba(Model):
                 ),
                 generated_text,
                 top_tokens,
-                new_input_length
+                new_input_length,
             )

             generations.append(generation)

@@ -166,7 +166,11 @@ class Seq2SeqLMBatch(Batch):
         )

     @tracer.start_as_current_span("filter")
-    def filter(self, request_ids: List[int]) -> Optional["Seq2SeqLMBatch"]:
+    def filter(
+        self, updated_requests: List[generate_pb2.UpdatedRequest]
+    ) -> Optional["Seq2SeqLMBatch"]:
+        request_ids = [r.id for r in updated_requests]
+
         if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
         if len(request_ids) == len(self):

@@ -28,7 +28,7 @@ class Batch(ABC):
        raise NotImplementedError

    @abstractmethod
-   def filter(self, request_ids: List[int]) -> "Batch":
+   def filter(self, updated_requests: List[generate_pb2.UpdatedRequest]) -> "Batch":
        raise NotImplementedError

    @classmethod

@@ -122,8 +122,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         return batch

     @tracer.start_as_current_span("filter")
-    def filter(self, request_ids: List[int]):
-        batch = super().filter(request_ids)
+    def filter(
+        self, updated_requests: List[generate_pb2.UpdatedRequest]
+    ) -> Optional["VlmCausalLMBatch"]:
+        batch = super().filter(updated_requests)
         batch.pixel_values = None
         batch.pixel_attention_mask = None
         batch.image_sizes = None

@@ -83,7 +83,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
        batch = self.cache.pop(request.batch_id)
        if batch is None:
            raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
-       filtered_batch = batch.filter(request.request_ids)
+       filtered_batch = batch.filter(request.updated_requests)
        self.cache.set(filtered_batch)

        return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())