mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-29 22:12:07 +00:00
working example
This commit is contained in:
parent
1cc86930a6
commit
35f27cbcc1
@ -1,10 +1,13 @@
|
|||||||
use std::cmp::min;
|
use std::cmp::{max, min};
|
||||||
|
use thiserror::Error;
|
||||||
use tokio::sync::{mpsc, oneshot};
|
use tokio::sync::{mpsc, oneshot};
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub(crate) struct BlockAllocation {
|
pub(crate) struct BlockAllocation {
|
||||||
pub blocks: Vec<u32>,
|
pub blocks: Vec<u32>,
|
||||||
pub slots: Vec<u32>,
|
pub slots: Vec<u32>,
|
||||||
|
prompt_tokens: u32,
|
||||||
|
decode_tokens: u32,
|
||||||
block_allocator: BlockAllocator,
|
block_allocator: BlockAllocator,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -12,6 +15,14 @@ impl BlockAllocation {
|
|||||||
pub(crate) fn len(&self) -> usize {
|
pub(crate) fn len(&self) -> usize {
|
||||||
self.slots.len()
|
self.slots.len()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn extend(&mut self, current_length: u32) -> Result<(), AllocationError> {
|
||||||
|
let remaining_tokens = max(self.prompt_tokens + self.decode_tokens - current_length, 1);
|
||||||
|
self.block_allocator
|
||||||
|
.clone()
|
||||||
|
.extend(self, remaining_tokens)
|
||||||
|
.await
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Drop for BlockAllocation {
|
impl Drop for BlockAllocation {
|
||||||
@ -48,11 +59,16 @@ impl BlockAllocator {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn allocate(&self, tokens: u32) -> Option<BlockAllocation> {
|
pub(crate) async fn allocate(
|
||||||
|
&self,
|
||||||
|
prompt_tokens: u32,
|
||||||
|
decode_tokens: u32,
|
||||||
|
) -> Result<BlockAllocation, AllocationError> {
|
||||||
let (response_sender, response_receiver) = oneshot::channel();
|
let (response_sender, response_receiver) = oneshot::channel();
|
||||||
self.block_allocator
|
self.block_allocator
|
||||||
.send(BlockAllocatorCommand::Allocate {
|
.send(BlockAllocatorCommand::Allocate {
|
||||||
tokens,
|
prompt_tokens,
|
||||||
|
decode_tokens,
|
||||||
response_sender,
|
response_sender,
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@ -63,10 +79,32 @@ impl BlockAllocator {
|
|||||||
.map(|(blocks, slots)| BlockAllocation {
|
.map(|(blocks, slots)| BlockAllocation {
|
||||||
blocks,
|
blocks,
|
||||||
slots,
|
slots,
|
||||||
|
prompt_tokens,
|
||||||
|
decode_tokens,
|
||||||
block_allocator: self.clone(),
|
block_allocator: self.clone(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn extend(
|
||||||
|
&self,
|
||||||
|
block_allocation: &mut BlockAllocation,
|
||||||
|
tokens: u32,
|
||||||
|
) -> Result<(), AllocationError> {
|
||||||
|
let (response_sender, response_receiver) = oneshot::channel();
|
||||||
|
self.block_allocator
|
||||||
|
.send(BlockAllocatorCommand::Allocate {
|
||||||
|
prompt_tokens: 0,
|
||||||
|
decode_tokens: tokens,
|
||||||
|
response_sender,
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let (blocks, slots) = response_receiver.await.unwrap()?;
|
||||||
|
block_allocation.blocks.extend(blocks);
|
||||||
|
block_allocation.slots.extend(slots);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn free(&self, blocks: Vec<u32>) {
|
pub(crate) fn free(&self, blocks: Vec<u32>) {
|
||||||
self.block_allocator
|
self.block_allocator
|
||||||
.send(BlockAllocatorCommand::Free { blocks })
|
.send(BlockAllocatorCommand::Free { blocks })
|
||||||
@ -86,10 +124,12 @@ async fn block_allocator_task(
|
|||||||
match cmd {
|
match cmd {
|
||||||
BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks),
|
BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks),
|
||||||
BlockAllocatorCommand::Allocate {
|
BlockAllocatorCommand::Allocate {
|
||||||
tokens,
|
prompt_tokens,
|
||||||
|
decode_tokens,
|
||||||
response_sender,
|
response_sender,
|
||||||
} => {
|
} => {
|
||||||
// let tokens = 16;
|
let decode_tokens = min(decode_tokens, block_size);
|
||||||
|
let tokens = prompt_tokens + decode_tokens;
|
||||||
|
|
||||||
// Apply window size
|
// Apply window size
|
||||||
let (required_blocks, repeats) = {
|
let (required_blocks, repeats) = {
|
||||||
@ -106,9 +146,8 @@ async fn block_allocator_task(
|
|||||||
(required_blocks, repeats)
|
(required_blocks, repeats)
|
||||||
};
|
};
|
||||||
|
|
||||||
let tokens = tokens as usize;
|
|
||||||
let allocation = if required_blocks > free_blocks.len() as u32 {
|
let allocation = if required_blocks > free_blocks.len() as u32 {
|
||||||
None
|
Err(AllocationError::NotEnoughPages)
|
||||||
} else {
|
} else {
|
||||||
let blocks =
|
let blocks =
|
||||||
free_blocks.split_off(free_blocks.len() - required_blocks as usize);
|
free_blocks.split_off(free_blocks.len() - required_blocks as usize);
|
||||||
@ -116,15 +155,12 @@ async fn block_allocator_task(
|
|||||||
(required_blocks * block_size * repeats as u32) as usize,
|
(required_blocks * block_size * repeats as u32) as usize,
|
||||||
);
|
);
|
||||||
|
|
||||||
'slots: for block_id in blocks.repeat(repeats).iter() {
|
for block_id in blocks.repeat(repeats).iter() {
|
||||||
for s in (block_id * block_size)..((block_id + 1) * block_size) {
|
for s in (block_id * block_size)..((block_id + 1) * block_size) {
|
||||||
slots.push(s);
|
slots.push(s);
|
||||||
if slots.len() == tokens {
|
|
||||||
break 'slots;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
Ok((blocks, slots))
|
||||||
Some((blocks, slots))
|
|
||||||
};
|
};
|
||||||
response_sender.send(allocation).unwrap();
|
response_sender.send(allocation).unwrap();
|
||||||
}
|
}
|
||||||
@ -138,7 +174,15 @@ enum BlockAllocatorCommand {
|
|||||||
blocks: Vec<u32>,
|
blocks: Vec<u32>,
|
||||||
},
|
},
|
||||||
Allocate {
|
Allocate {
|
||||||
tokens: u32,
|
prompt_tokens: u32,
|
||||||
response_sender: oneshot::Sender<Option<(Vec<u32>, Vec<u32>)>>,
|
decode_tokens: u32,
|
||||||
|
#[allow(clippy::type_complexity)]
|
||||||
|
response_sender: oneshot::Sender<Result<(Vec<u32>, Vec<u32>), AllocationError>>,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
pub enum AllocationError {
|
||||||
|
#[error("Not enough pages")]
|
||||||
|
NotEnoughPages,
|
||||||
|
}
|
||||||
|
@ -295,20 +295,20 @@ impl State {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
let tokens = entry.request.input_length
|
let decode_tokens =
|
||||||
+ entry.request.stopping_parameters.max_new_tokens
|
entry.request.stopping_parameters.max_new_tokens + self.speculate - 1;
|
||||||
+ self.speculate
|
match block_allocator
|
||||||
- 1;
|
.allocate(entry.request.input_length, decode_tokens)
|
||||||
|
.await
|
||||||
match block_allocator.allocate(tokens).await {
|
{
|
||||||
None => {
|
Err(_) => {
|
||||||
// Entry is over budget
|
// Entry is over budget
|
||||||
// Add it back to the front
|
// Add it back to the front
|
||||||
tracing::debug!("Over budget: not enough free blocks");
|
tracing::debug!("Over budget: not enough free blocks");
|
||||||
self.entries.push_front((id, entry));
|
self.entries.push_front((id, entry));
|
||||||
break 'entry_loop;
|
break 'entry_loop;
|
||||||
}
|
}
|
||||||
Some(block_allocation) => {
|
Ok(block_allocation) => {
|
||||||
tracing::debug!("Allocation: {block_allocation:?}");
|
tracing::debug!("Allocation: {block_allocation:?}");
|
||||||
max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
|
max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
|
||||||
Some(block_allocation)
|
Some(block_allocation)
|
||||||
|
@ -247,7 +247,7 @@ async fn prefill(
|
|||||||
filter_send_generations(generations, entries);
|
filter_send_generations(generations, entries);
|
||||||
|
|
||||||
// Filter next batch and remove requests that were stopped
|
// Filter next batch and remove requests that were stopped
|
||||||
let next_batch = filter_batch(client, next_batch, entries).await;
|
let next_batch = filter_batch(client, next_batch, entries, false).await;
|
||||||
|
|
||||||
metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill");
|
metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill");
|
||||||
metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill");
|
metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill");
|
||||||
@ -288,10 +288,10 @@ async fn decode(
|
|||||||
// Send generated tokens and filter stopped entries
|
// Send generated tokens and filter stopped entries
|
||||||
filter_send_generations(generations, entries);
|
filter_send_generations(generations, entries);
|
||||||
|
|
||||||
filter_update_allocations(entries).await;
|
let updated = filter_update_allocations(entries).await;
|
||||||
|
|
||||||
// Filter next batch and remove requests that were stopped
|
// Filter next batch and remove requests that were stopped
|
||||||
let next_batch = filter_batch(client, next_batch, entries).await;
|
let next_batch = filter_batch(client, next_batch, entries, updated).await;
|
||||||
|
|
||||||
if let Some(concat_duration) = timings.concat {
|
if let Some(concat_duration) = timings.concat {
|
||||||
metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode");
|
metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode");
|
||||||
@ -322,11 +322,12 @@ async fn filter_batch(
|
|||||||
client: &mut ShardedClient,
|
client: &mut ShardedClient,
|
||||||
next_batch: Option<CachedBatch>,
|
next_batch: Option<CachedBatch>,
|
||||||
entries: &IntMap<u64, Entry>,
|
entries: &IntMap<u64, Entry>,
|
||||||
|
force_update: bool,
|
||||||
) -> Option<CachedBatch> {
|
) -> Option<CachedBatch> {
|
||||||
let batch = next_batch?;
|
let batch = next_batch?;
|
||||||
|
|
||||||
// No need to filter
|
// No need to filter
|
||||||
if batch.size as usize == entries.len() {
|
if batch.size as usize == entries.len() && !force_update {
|
||||||
return Some(batch);
|
return Some(batch);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -348,6 +349,7 @@ async fn filter_batch(
|
|||||||
.as_ref()
|
.as_ref()
|
||||||
.map(|alloc| (alloc.blocks.clone(), alloc.slots.clone()))
|
.map(|alloc| (alloc.blocks.clone(), alloc.slots.clone()))
|
||||||
.unwrap_or((Vec::new(), Vec::new()));
|
.unwrap_or((Vec::new(), Vec::new()));
|
||||||
|
|
||||||
UpdatedRequest {
|
UpdatedRequest {
|
||||||
id: *request_id,
|
id: *request_id,
|
||||||
blocks,
|
blocks,
|
||||||
@ -393,19 +395,44 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
|
|||||||
/// Check if block allocations need to be extended
|
/// Check if block allocations need to be extended
|
||||||
/// If we don't have enough blocks, request will be filtered with an OutOfPages error
|
/// If we don't have enough blocks, request will be filtered with an OutOfPages error
|
||||||
#[instrument(skip_all)]
|
#[instrument(skip_all)]
|
||||||
async fn filter_update_allocations(entries: &mut IntMap<u64, Entry>) {
|
async fn filter_update_allocations(entries: &mut IntMap<u64, Entry>) -> bool {
|
||||||
entries.retain(|request_id, entry| {
|
let ids: Vec<u64> = entries
|
||||||
if entry.block_allocation.is_none() {
|
.iter()
|
||||||
return true;
|
.filter_map(|(id, entry)| {
|
||||||
|
entry
|
||||||
|
.block_allocation
|
||||||
|
.as_ref()
|
||||||
|
.map(|block_allocation| {
|
||||||
|
if entry.current_length > block_allocation.len() as u32 {
|
||||||
|
// We need to re-allocate
|
||||||
|
Some(*id)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
}
|
}
|
||||||
|
})
|
||||||
|
.unwrap_or(None)
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
// We can unwrap since we already validated above that block_allocation is not None
|
for id in ids.iter() {
|
||||||
let mut block_allocation = entry.block_allocation.as_ref().unwrap();
|
// Get entry
|
||||||
|
// We can `expect` here as the request id should always be in the entries
|
||||||
|
let extension = {
|
||||||
|
let entry = entries
|
||||||
|
.get_mut(id)
|
||||||
|
.expect("ID not found in entries. This is a bug.");
|
||||||
|
entry
|
||||||
|
.block_allocation
|
||||||
|
.as_mut()
|
||||||
|
.unwrap()
|
||||||
|
.extend(entry.current_length)
|
||||||
|
.await
|
||||||
|
};
|
||||||
|
|
||||||
// Nothing to update
|
if extension.is_err() {
|
||||||
if entry.current_length <= block_allocation.len() as u32 {
|
let entry = entries
|
||||||
return true;
|
.remove(id)
|
||||||
}
|
.expect("ID not found in entries. This is a bug.");
|
||||||
|
|
||||||
// Create and enter a span to link this function back to the entry
|
// Create and enter a span to link this function back to the entry
|
||||||
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
|
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
|
||||||
@ -414,13 +441,12 @@ async fn filter_update_allocations(entries: &mut IntMap<u64, Entry>) {
|
|||||||
tracing::error!("{err}");
|
tracing::error!("{err}");
|
||||||
|
|
||||||
// unwrap_or is valid here as we don't care if the receiver is gone.
|
// unwrap_or is valid here as we don't care if the receiver is gone.
|
||||||
entry
|
entry.response_tx.send(Err(err)).unwrap_or(());
|
||||||
.response_tx
|
}
|
||||||
.send(Err(err))
|
}
|
||||||
.unwrap_or(());
|
|
||||||
|
|
||||||
false
|
// If ids is not empty, we need to update
|
||||||
});
|
!ids.is_empty()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Send responses through the `entry` response channel
|
/// Send responses through the `entry` response channel
|
||||||
|
@ -402,9 +402,6 @@ class FlashCausalLMBatch(Batch):
|
|||||||
) -> Optional["FlashCausalLMBatch"]:
|
) -> Optional["FlashCausalLMBatch"]:
|
||||||
if len(updated_requests) == 0:
|
if len(updated_requests) == 0:
|
||||||
raise ValueError("Batch must have at least one request")
|
raise ValueError("Batch must have at least one request")
|
||||||
# We assume that if len(requests) == len(self) then the requests are the same
|
|
||||||
if len(updated_requests) == len(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
device = self.input_ids.device
|
device = self.input_ids.device
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user