mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 06:42:10 +00:00
re-working logic, wip
This commit is contained in:
parent
298bf31e69
commit
713d70b443
@ -1,6 +1,6 @@
|
|||||||
use std::cmp::min;
|
use std::cmp::min;
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tokio::sync::{mpsc, oneshot};
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub(crate) struct BlockAllocation {
|
pub(crate) struct BlockAllocation {
|
||||||
@ -16,13 +16,23 @@ impl BlockAllocation {
|
|||||||
self.slots.len()
|
self.slots.len()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn extend(&mut self, cache_length: u32) -> Result<(), AllocationError> {
|
pub(crate) fn extend(&mut self) -> Result<(), AllocationError> {
|
||||||
let remaining_tokens =
|
let (block, slots) = self.block_allocator.allocate_block()?;
|
||||||
(self.prompt_tokens + self.decode_tokens).saturating_sub(cache_length);
|
|
||||||
self.block_allocator
|
match self.block_allocator.window_size {
|
||||||
.clone()
|
None => {
|
||||||
.extend(self, remaining_tokens)
|
self.blocks.push(block);
|
||||||
.await
|
self.slots.extend(slots);
|
||||||
|
}
|
||||||
|
Some(window_size) => {
|
||||||
|
if self.len() as u32 > window_size {
|
||||||
|
let total_tokens = self.prompt_tokens + self.decode_tokens;
|
||||||
|
|
||||||
|
let repeats = (total_tokens + window_size - 1) / window_size;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -34,8 +44,9 @@ impl Drop for BlockAllocation {
|
|||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub(crate) struct BlockAllocator {
|
pub(crate) struct BlockAllocator {
|
||||||
/// Channel to communicate with the background task
|
free_blocks: Arc<Mutex<Vec<u32>>>,
|
||||||
block_allocator: mpsc::UnboundedSender<BlockAllocatorCommand>,
|
block_size: u32,
|
||||||
|
window_size: Option<u32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl BlockAllocator {
|
impl BlockAllocator {
|
||||||
@ -44,39 +55,105 @@ impl BlockAllocator {
|
|||||||
block_size: u32,
|
block_size: u32,
|
||||||
window_size: Option<u32>,
|
window_size: Option<u32>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
// Create channel
|
let blocks = max_batch_total_tokens / block_size;
|
||||||
let (sender, receiver) = mpsc::unbounded_channel();
|
// Block 0 is reserved for health checks
|
||||||
|
let free_blocks: Vec<u32> = (1..blocks).collect();
|
||||||
// Launch background queue task
|
|
||||||
tokio::spawn(block_allocator_task(
|
|
||||||
max_batch_total_tokens / block_size,
|
|
||||||
block_size,
|
|
||||||
window_size,
|
|
||||||
receiver,
|
|
||||||
));
|
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
block_allocator: sender,
|
free_blocks: Arc::new(Mutex::new(free_blocks)),
|
||||||
|
block_size,
|
||||||
|
window_size,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn allocate(
|
fn allocate_block(&self) -> Result<(u32, Vec<u32>), AllocationError> {
|
||||||
|
let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired");
|
||||||
|
|
||||||
|
if free_blocks.is_empty() {
|
||||||
|
return Err(AllocationError::NotEnoughPages);
|
||||||
|
}
|
||||||
|
|
||||||
|
let block_id = free_blocks.pop().unwrap();
|
||||||
|
let slots = ((block_id * self.block_size)..((block_id + 1) * self.block_size)).collect();
|
||||||
|
Ok((block_id, slots))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// For prompt tokens, we allocate enough blocks to cover all tokens
|
||||||
|
/// For decode tokens, we allocate block by block
|
||||||
|
///
|
||||||
|
/// If prompt tokens + min(decode_tokens, block_size) > window size, we repeat blocks and slots
|
||||||
|
fn allocate(
|
||||||
|
&self,
|
||||||
|
prompt_tokens: u32,
|
||||||
|
decode_tokens: u32,
|
||||||
|
) -> Result<(Vec<u32>, Vec<u32>), AllocationError> {
|
||||||
|
// let decode_tokens = min(decode_tokens, self.block_size);
|
||||||
|
// let tokens = prompt_tokens + decode_tokens;
|
||||||
|
|
||||||
|
let required_prompt_blocks = (prompt_tokens + self.block_size - 1) / self.block_size;
|
||||||
|
// prompt blocks + a single block for decode
|
||||||
|
let required_blocks = required_prompt_blocks + 1;
|
||||||
|
|
||||||
|
let (required_blocks, repeats) = match self.window_size {
|
||||||
|
// Nothing to do
|
||||||
|
None => (required_blocks, 1),
|
||||||
|
Some(window_size) => {
|
||||||
|
// Number of blocks needed for this window size
|
||||||
|
let window_size_required_blocks = (window_size + self.block_size - 1) / self.block_size;
|
||||||
|
// Number of times we will need to repeat blocks to cover the required allocation
|
||||||
|
let repeats = (required_blocks + window_size_required_blocks -1) / window_size_required_blocks;
|
||||||
|
let required_blocks = min(required_blocks, window_size_required_blocks);
|
||||||
|
|
||||||
|
(required_blocks, repeats)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
/// if prompt + decode < window size => do nothing
|
||||||
|
/// if prompt + decode > window size => do normal until we reach window size then
|
||||||
|
|
||||||
|
// Apply window size
|
||||||
|
let (required_blocks, repeats) = {
|
||||||
|
let (tokens, repeats) = match self.window_size {
|
||||||
|
None => (tokens, 1),
|
||||||
|
Some(window_size) => {
|
||||||
|
let repeats = (tokens + window_size - 1) / window_size;
|
||||||
|
let tokens = min(tokens, window_size);
|
||||||
|
(tokens, repeats as usize)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
// Pad to a multiple of block size
|
||||||
|
let required_blocks = (tokens + self.block_size - 1) / self.block_size;
|
||||||
|
(required_blocks, repeats)
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired");
|
||||||
|
|
||||||
|
if required_blocks > free_blocks.len() as u32 {
|
||||||
|
Err(AllocationError::NotEnoughPages)
|
||||||
|
} else {
|
||||||
|
let n_free_blocks = free_blocks.len();
|
||||||
|
let blocks =
|
||||||
|
free_blocks.split_off(n_free_blocks - required_blocks as usize);
|
||||||
|
let mut slots = Vec::with_capacity(
|
||||||
|
(required_blocks * self.block_size * repeats as u32) as usize,
|
||||||
|
);
|
||||||
|
|
||||||
|
for block_id in blocks.repeat(repeats).iter() {
|
||||||
|
for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
|
||||||
|
slots.push(s);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok((blocks, slots))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn block_allocation(
|
||||||
&self,
|
&self,
|
||||||
prompt_tokens: u32,
|
prompt_tokens: u32,
|
||||||
decode_tokens: u32,
|
decode_tokens: u32,
|
||||||
) -> Result<BlockAllocation, AllocationError> {
|
) -> Result<BlockAllocation, AllocationError> {
|
||||||
let (response_sender, response_receiver) = oneshot::channel();
|
self.allocate_inner(prompt_tokens, decode_tokens)
|
||||||
self.block_allocator
|
|
||||||
.send(BlockAllocatorCommand::Allocate {
|
|
||||||
prompt_tokens,
|
|
||||||
decode_tokens,
|
|
||||||
response_sender,
|
|
||||||
})
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
response_receiver
|
|
||||||
.await
|
|
||||||
.unwrap()
|
|
||||||
.map(|(blocks, slots)| BlockAllocation {
|
.map(|(blocks, slots)| BlockAllocation {
|
||||||
blocks,
|
blocks,
|
||||||
slots,
|
slots,
|
||||||
@ -86,103 +163,11 @@ impl BlockAllocator {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn extend(
|
|
||||||
&self,
|
|
||||||
block_allocation: &mut BlockAllocation,
|
|
||||||
tokens: u32,
|
|
||||||
) -> Result<(), AllocationError> {
|
|
||||||
let (response_sender, response_receiver) = oneshot::channel();
|
|
||||||
self.block_allocator
|
|
||||||
.send(BlockAllocatorCommand::Allocate {
|
|
||||||
prompt_tokens: 0,
|
|
||||||
decode_tokens: tokens,
|
|
||||||
response_sender,
|
|
||||||
})
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let (blocks, slots) = response_receiver.await.unwrap()?;
|
|
||||||
block_allocation.blocks.extend(blocks);
|
|
||||||
block_allocation.slots.extend(slots);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn free(&self, blocks: Vec<u32>) {
|
pub(crate) fn free(&self, blocks: Vec<u32>) {
|
||||||
self.block_allocator
|
self.free_blocks.lock().expect("Lock could not be acquired. This is a bug.").extend(blocks)
|
||||||
.send(BlockAllocatorCommand::Free { blocks })
|
|
||||||
.unwrap();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn block_allocator_task(
|
|
||||||
blocks: u32,
|
|
||||||
block_size: u32,
|
|
||||||
window_size: Option<u32>,
|
|
||||||
mut receiver: mpsc::UnboundedReceiver<BlockAllocatorCommand>,
|
|
||||||
) {
|
|
||||||
// Block 0 is reserved for health checks
|
|
||||||
let mut free_blocks: Vec<u32> = (1..blocks).collect();
|
|
||||||
while let Some(cmd) = receiver.recv().await {
|
|
||||||
match cmd {
|
|
||||||
BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks),
|
|
||||||
BlockAllocatorCommand::Allocate {
|
|
||||||
prompt_tokens,
|
|
||||||
decode_tokens,
|
|
||||||
response_sender,
|
|
||||||
} => {
|
|
||||||
let decode_tokens = min(decode_tokens, block_size);
|
|
||||||
let tokens = prompt_tokens + decode_tokens;
|
|
||||||
|
|
||||||
// FIXME: window size is not working
|
|
||||||
// Apply window size
|
|
||||||
let (required_blocks, repeats) = {
|
|
||||||
let (tokens, repeats) = match window_size {
|
|
||||||
None => (tokens, 1),
|
|
||||||
Some(window_size) => {
|
|
||||||
let repeats = (tokens + window_size - 1) / window_size;
|
|
||||||
let tokens = min(tokens, window_size);
|
|
||||||
(tokens, repeats as usize)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
// Pad to a multiple of block size
|
|
||||||
let required_blocks = (tokens + block_size - 1) / block_size;
|
|
||||||
(required_blocks, repeats)
|
|
||||||
};
|
|
||||||
|
|
||||||
let allocation = if required_blocks > free_blocks.len() as u32 {
|
|
||||||
Err(AllocationError::NotEnoughPages)
|
|
||||||
} else {
|
|
||||||
let blocks =
|
|
||||||
free_blocks.split_off(free_blocks.len() - required_blocks as usize);
|
|
||||||
let mut slots = Vec::with_capacity(
|
|
||||||
(required_blocks * block_size * repeats as u32) as usize,
|
|
||||||
);
|
|
||||||
|
|
||||||
for block_id in blocks.repeat(repeats).iter() {
|
|
||||||
for s in (block_id * block_size)..((block_id + 1) * block_size) {
|
|
||||||
slots.push(s);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok((blocks, slots))
|
|
||||||
};
|
|
||||||
response_sender.send(allocation).unwrap();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
enum BlockAllocatorCommand {
|
|
||||||
Free {
|
|
||||||
blocks: Vec<u32>,
|
|
||||||
},
|
|
||||||
Allocate {
|
|
||||||
prompt_tokens: u32,
|
|
||||||
decode_tokens: u32,
|
|
||||||
#[allow(clippy::type_complexity)]
|
|
||||||
response_sender: oneshot::Sender<Result<(Vec<u32>, Vec<u32>), AllocationError>>,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Error, Debug)]
|
#[derive(Error, Debug)]
|
||||||
pub enum AllocationError {
|
pub enum AllocationError {
|
||||||
#[error("Not enough pages")]
|
#[error("Not enough pages")]
|
||||||
|
@ -284,7 +284,6 @@ impl State {
|
|||||||
entry.request.stopping_parameters.max_new_tokens + self.speculate;
|
entry.request.stopping_parameters.max_new_tokens + self.speculate;
|
||||||
match block_allocator
|
match block_allocator
|
||||||
.allocate(entry.request.input_length, decode_tokens)
|
.allocate(entry.request.input_length, decode_tokens)
|
||||||
.await
|
|
||||||
{
|
{
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
// Entry is over budget
|
// Entry is over budget
|
||||||
|
@ -428,8 +428,7 @@ async fn filter_update_allocations(entries: &mut IntMap<u64, Entry>) -> bool {
|
|||||||
.block_allocation
|
.block_allocation
|
||||||
.as_mut()
|
.as_mut()
|
||||||
.expect("We checked that the block allocation exists above")
|
.expect("We checked that the block allocation exists above")
|
||||||
.extend(entry.cache_length)
|
.extend()
|
||||||
.await
|
|
||||||
};
|
};
|
||||||
|
|
||||||
if extension.is_err() {
|
if extension.is_err() {
|
||||||
|
Loading…
Reference in New Issue
Block a user