working example

OlivierDehaene 2024-06-05 18:47:16 +02:00
parent 1cc86930a6
commit 35f27cbcc1
4 changed files with 122 additions and 55 deletions

View File

@@ -1,10 +1,13 @@
-use std::cmp::min;
+use std::cmp::{max, min};
+use thiserror::Error;
 use tokio::sync::{mpsc, oneshot};
 
 #[derive(Debug, Clone)]
 pub(crate) struct BlockAllocation {
     pub blocks: Vec<u32>,
     pub slots: Vec<u32>,
+    prompt_tokens: u32,
+    decode_tokens: u32,
     block_allocator: BlockAllocator,
 }
@@ -12,6 +15,14 @@ impl BlockAllocation {
     pub(crate) fn len(&self) -> usize {
         self.slots.len()
     }
+
+    pub(crate) async fn extend(&mut self, current_length: u32) -> Result<(), AllocationError> {
+        let remaining_tokens = max(self.prompt_tokens + self.decode_tokens - current_length, 1);
+        self.block_allocator
+            .clone()
+            .extend(self, remaining_tokens)
+            .await
+    }
 }
 
 impl Drop for BlockAllocation {
@@ -48,11 +59,16 @@ impl BlockAllocator {
         }
     }
 
-    pub(crate) async fn allocate(&self, tokens: u32) -> Option<BlockAllocation> {
+    pub(crate) async fn allocate(
+        &self,
+        prompt_tokens: u32,
+        decode_tokens: u32,
+    ) -> Result<BlockAllocation, AllocationError> {
         let (response_sender, response_receiver) = oneshot::channel();
         self.block_allocator
             .send(BlockAllocatorCommand::Allocate {
-                tokens,
+                prompt_tokens,
+                decode_tokens,
                 response_sender,
             })
             .unwrap();
@@ -63,10 +79,32 @@ impl BlockAllocator {
            .map(|(blocks, slots)| BlockAllocation {
                blocks,
                slots,
+               prompt_tokens,
+               decode_tokens,
                block_allocator: self.clone(),
            })
    }
 
+    pub(crate) async fn extend(
+        &self,
+        block_allocation: &mut BlockAllocation,
+        tokens: u32,
+    ) -> Result<(), AllocationError> {
+        let (response_sender, response_receiver) = oneshot::channel();
+        self.block_allocator
+            .send(BlockAllocatorCommand::Allocate {
+                prompt_tokens: 0,
+                decode_tokens: tokens,
+                response_sender,
+            })
+            .unwrap();
+
+        let (blocks, slots) = response_receiver.await.unwrap()?;
+        block_allocation.blocks.extend(blocks);
+        block_allocation.slots.extend(slots);
+        Ok(())
+    }
+
     pub(crate) fn free(&self, blocks: Vec<u32>) {
         self.block_allocator
             .send(BlockAllocatorCommand::Free { blocks })
@@ -86,10 +124,12 @@ async fn block_allocator_task(
         match cmd {
             BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks),
             BlockAllocatorCommand::Allocate {
-                tokens,
+                prompt_tokens,
+                decode_tokens,
                 response_sender,
             } => {
-                // let tokens = 16;
+                let decode_tokens = min(decode_tokens, block_size);
+                let tokens = prompt_tokens + decode_tokens;
 
                 // Apply window size
                 let (required_blocks, repeats) = {
@@ -106,9 +146,8 @@ async fn block_allocator_task(
                     (required_blocks, repeats)
                 };
 
-                let tokens = tokens as usize;
                 let allocation = if required_blocks > free_blocks.len() as u32 {
-                    None
+                    Err(AllocationError::NotEnoughPages)
                 } else {
                     let blocks =
                         free_blocks.split_off(free_blocks.len() - required_blocks as usize);
@@ -116,15 +155,12 @@ async fn block_allocator_task(
                         (required_blocks * block_size * repeats as u32) as usize,
                     );
 
-                    'slots: for block_id in blocks.repeat(repeats).iter() {
+                    for block_id in blocks.repeat(repeats).iter() {
                         for s in (block_id * block_size)..((block_id + 1) * block_size) {
                             slots.push(s);
-                            if slots.len() == tokens {
-                                break 'slots;
-                            }
                         }
                     }
-                    Some((blocks, slots))
+                    Ok((blocks, slots))
                 };
                 response_sender.send(allocation).unwrap();
             }
@@ -138,7 +174,15 @@ enum BlockAllocatorCommand {
        blocks: Vec<u32>,
    },
    Allocate {
-        tokens: u32,
-        response_sender: oneshot::Sender<Option<(Vec<u32>, Vec<u32>)>>,
+        prompt_tokens: u32,
+        decode_tokens: u32,
+        #[allow(clippy::type_complexity)]
+        response_sender: oneshot::Sender<Result<(Vec<u32>, Vec<u32>), AllocationError>>,
    },
 }
+
+#[derive(Error, Debug)]
+pub enum AllocationError {
+    #[error("Not enough pages")]
+    NotEnoughPages,
+}
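
Note on the slot expansion: since the early `break 'slots` is gone, an allocation now always carries one slot per position in every block it owns, which is what makes the later `extend` bookkeeping straightforward. Below is a minimal standalone sketch of that expansion; the block ids, `block_size`, and `repeats` values are made up for illustration.

// Standalone sketch of the slot expansion performed in `block_allocator_task`
// after this change: every allocated block contributes all of its slots.
fn expand_slots(blocks: &[u32], block_size: u32, repeats: usize) -> Vec<u32> {
    let mut slots = Vec::with_capacity(blocks.len() * block_size as usize * repeats);
    for block_id in blocks.repeat(repeats).iter() {
        for s in (block_id * block_size)..((block_id + 1) * block_size) {
            slots.push(s);
        }
    }
    slots
}

fn main() {
    // Two 16-slot blocks, no sliding window (repeats = 1); hypothetical values.
    let slots = expand_slots(&[7, 8], 16, 1);
    assert_eq!(slots.len(), 32);
    assert_eq!(slots[0], 7 * 16);
    println!("{slots:?}");
}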

View File

@@ -295,20 +295,20 @@ impl State {
                     break;
                 }
 
-                let tokens = entry.request.input_length
-                    + entry.request.stopping_parameters.max_new_tokens
-                    + self.speculate
-                    - 1;
-
-                match block_allocator.allocate(tokens).await {
-                    None => {
+                let decode_tokens =
+                    entry.request.stopping_parameters.max_new_tokens + self.speculate - 1;
+                match block_allocator
+                    .allocate(entry.request.input_length, decode_tokens)
+                    .await
+                {
+                    Err(_) => {
                         // Entry is over budget
                         // Add it back to the front
                         tracing::debug!("Over budget: not enough free blocks");
                         self.entries.push_front((id, entry));
                         break 'entry_loop;
                     }
-                    Some(block_allocation) => {
+                    Ok(block_allocation) => {
                         tracing::debug!("Allocation: {block_allocation:?}");
                         max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
                         Some(block_allocation)
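
With this change the queue requests the prompt length plus a decode budget, and `block_allocator_task` caps that decode budget to a single block up front, leaving further growth to later `extend` calls. A small arithmetic sketch with hypothetical request parameters; the ceil-division to blocks is an assumption, since the body of the `required_blocks` computation is elided in this hunk.

fn main() {
    // Hypothetical request parameters, not taken from a real trace.
    let input_length: u32 = 130; // prompt tokens
    let max_new_tokens: u32 = 512;
    let speculate: u32 = 0;
    let block_size: u32 = 16;

    // What the queue now asks for (see the hunk above).
    let decode_tokens = max_new_tokens + speculate - 1;

    // What `block_allocator_task` actually reserves up front:
    // the decode budget is capped to a single block of slots.
    let tokens = input_length + decode_tokens.min(block_size);

    // Assumed ceil-division from tokens to blocks (no sliding window).
    let required_blocks = (tokens + block_size - 1) / block_size;

    println!("initial reservation: {tokens} tokens -> {required_blocks} blocks");
}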

View File

@@ -247,7 +247,7 @@ async fn prefill(
     filter_send_generations(generations, entries);
 
     // Filter next batch and remove requests that were stopped
-    let next_batch = filter_batch(client, next_batch, entries).await;
+    let next_batch = filter_batch(client, next_batch, entries, false).await;
 
     metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill");
     metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill");
@@ -288,10 +288,10 @@ async fn decode(
     // Send generated tokens and filter stopped entries
     filter_send_generations(generations, entries);
-    filter_update_allocations(entries).await;
+    let updated = filter_update_allocations(entries).await;
 
     // Filter next batch and remove requests that were stopped
-    let next_batch = filter_batch(client, next_batch, entries).await;
+    let next_batch = filter_batch(client, next_batch, entries, updated).await;
 
     if let Some(concat_duration) = timings.concat {
         metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode");
@@ -322,11 +322,12 @@ async fn filter_batch(
     client: &mut ShardedClient,
     next_batch: Option<CachedBatch>,
     entries: &IntMap<u64, Entry>,
+    force_update: bool,
 ) -> Option<CachedBatch> {
     let batch = next_batch?;
 
     // No need to filter
-    if batch.size as usize == entries.len() {
+    if batch.size as usize == entries.len() && !force_update {
         return Some(batch);
     }
@@ -348,6 +349,7 @@ async fn filter_batch(
             .as_ref()
             .map(|alloc| (alloc.blocks.clone(), alloc.slots.clone()))
             .unwrap_or((Vec::new(), Vec::new()));
+
         UpdatedRequest {
             id: *request_id,
             blocks,
@@ -393,19 +395,44 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
 /// Check if block allocations need to be extended
 /// If we don't have enough blocks, request will be filtered with an OutOfPages error
 #[instrument(skip_all)]
-async fn filter_update_allocations(entries: &mut IntMap<u64, Entry>) {
-    entries.retain(|request_id, entry| {
-        if entry.block_allocation.is_none() {
-            return true;
-        }
-
-        // We can unwrap since we already validated above that block_allocation is not None
-        let mut block_allocation = entry.block_allocation.as_ref().unwrap();
-
-        // Nothing to update
-        if entry.current_length <= block_allocation.len() as u32 {
-            return true;
-        }
-
+async fn filter_update_allocations(entries: &mut IntMap<u64, Entry>) -> bool {
+    let ids: Vec<u64> = entries
+        .iter()
+        .filter_map(|(id, entry)| {
+            entry
+                .block_allocation
+                .as_ref()
+                .map(|block_allocation| {
+                    if entry.current_length > block_allocation.len() as u32 {
+                        // We need to re-allocate
+                        Some(*id)
+                    } else {
+                        None
+                    }
+                })
+                .unwrap_or(None)
+        })
+        .collect();
+
+    for id in ids.iter() {
+        // Get entry
+        // We can `expect` here as the request id should always be in the entries
+        let extension = {
+            let entry = entries
+                .get_mut(id)
+                .expect("ID not found in entries. This is a bug.");
+            entry
+                .block_allocation
+                .as_mut()
+                .unwrap()
+                .extend(entry.current_length)
+                .await
+        };
+
+        if extension.is_err() {
+            let entry = entries
+                .remove(id)
+                .expect("ID not found in entries. This is a bug.");
+
             // Create and enter a span to link this function back to the entry
             let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
@@ -414,13 +441,12 @@ async fn filter_update_allocations(entries: &mut IntMap<u64, Entry>) {
             tracing::error!("{err}");
             // unwrap_or is valid here as we don't care if the receiver is gone.
-            entry
-                .response_tx
-                .send(Err(err))
-                .unwrap_or(());
-
-        false
-    });
+            entry.response_tx.send(Err(err)).unwrap_or(());
+        }
+    }
+
+    // If ids is not empty, we need to update
+    !ids.is_empty()
 }
 
 /// Send responses through the `entry` response channel
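
The `retain` closure is replaced by a two-phase pass: first collect the ids whose `current_length` outgrew their allocation, then extend or remove them, because `retain` cannot await the async extension. A self-contained sketch of the same pattern over a plain `HashMap`; the `Entry`, `try_extend`, and threshold values are simplified stand-ins, not the router's types.

use std::collections::HashMap;

// Simplified stand-in for the router's `Entry`.
struct Entry {
    current_length: u32,
    allocated_slots: u32,
}

// Pretend extension succeeds only while the entry stays under 64 slots.
fn try_extend(entry: &mut Entry) -> Result<(), ()> {
    if entry.current_length <= 64 {
        entry.allocated_slots = entry.current_length;
        Ok(())
    } else {
        Err(())
    }
}

/// Two-phase pass: collect the ids that outgrew their allocation, then
/// mutate or remove them. Returns true if anything needed an update,
/// mirroring the return value of `filter_update_allocations`.
fn filter_update(entries: &mut HashMap<u64, Entry>) -> bool {
    let ids: Vec<u64> = entries
        .iter()
        .filter(|(_, e)| e.current_length > e.allocated_slots)
        .map(|(id, _)| *id)
        .collect();

    for id in ids.iter() {
        let extended = try_extend(entries.get_mut(id).expect("id must exist"));
        if extended.is_err() {
            // Mirror the router: drop the entry so the error can be reported.
            entries.remove(id);
        }
    }

    !ids.is_empty()
}

fn main() {
    let mut entries = HashMap::new();
    entries.insert(0, Entry { current_length: 40, allocated_slots: 32 }); // can extend
    entries.insert(1, Entry { current_length: 80, allocated_slots: 32 }); // over budget
    let updated = filter_update(&mut entries);
    assert!(updated);
    assert!(entries.contains_key(&0) && !entries.contains_key(&1));
}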

View File

@@ -402,9 +402,6 @@ class FlashCausalLMBatch(Batch):
     ) -> Optional["FlashCausalLMBatch"]:
         if len(updated_requests) == 0:
             raise ValueError("Batch must have at least one request")
 
-        # We assume that if len(requests) == len(self) then the requests are the same
-        if len(updated_requests) == len(self):
-            return self
 
         device = self.input_ids.device