mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-23 07:52:06 +00:00
small refactor
This commit is contained in:
parent
713d70b443
commit
6983ec9537
@ -1,44 +1,55 @@
|
|||||||
use std::cmp::min;
|
|
||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, Mutex};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub(crate) struct BlockAllocation {
|
pub(crate) struct BlockAllocation {
|
||||||
pub blocks: Vec<u32>,
|
allocated_blocks: Vec<u32>,
|
||||||
pub slots: Vec<u32>,
|
allocated_slots: Vec<u32>,
|
||||||
prompt_tokens: u32,
|
required_blocks: usize,
|
||||||
decode_tokens: u32,
|
required_slots: usize,
|
||||||
block_allocator: BlockAllocator,
|
block_allocator: BlockAllocator,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl BlockAllocation {
|
impl BlockAllocation {
|
||||||
pub(crate) fn len(&self) -> usize {
|
pub(crate) fn len(&self) -> usize {
|
||||||
self.slots.len()
|
self.allocated_slots.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn blocks(&self) -> &[u32] {
|
||||||
|
&self.allocated_blocks
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn slots(&self) -> &[u32] {
|
||||||
|
&self.allocated_slots
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn extend(&mut self) -> Result<(), AllocationError> {
|
pub(crate) fn extend(&mut self) -> Result<(), AllocationError> {
|
||||||
let (block, slots) = self.block_allocator.allocate_block()?;
|
let (block, slots) = self.block_allocator.allocate_block()?;
|
||||||
|
// Add block and slots to current allocation
|
||||||
|
self.allocated_blocks.push(block);
|
||||||
|
self.allocated_slots.extend(slots);
|
||||||
|
|
||||||
match self.block_allocator.window_size {
|
if let Some(window_size) = self.block_allocator.window_size {
|
||||||
None => {
|
// if we have more slots than the window size,
|
||||||
self.blocks.push(block);
|
// we will never need to re-allocate and we can just repeat the blocks/slots
|
||||||
self.slots.extend(slots);
|
let window_size = window_size as usize;
|
||||||
}
|
if self.len() > window_size {
|
||||||
Some(window_size) => {
|
let repeats = (self.required_slots + window_size - 1) / window_size;
|
||||||
if self.len() as u32 > window_size {
|
self.allocated_blocks = self.allocated_blocks.repeat(repeats);
|
||||||
let total_tokens = self.prompt_tokens + self.decode_tokens;
|
self.allocated_blocks.truncate(self.required_blocks);
|
||||||
|
self.allocated_slots = self.allocated_slots.repeat(repeats);
|
||||||
let repeats = (total_tokens + window_size - 1) / window_size;
|
self.allocated_slots.truncate(self.required_slots);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Drop for BlockAllocation {
|
impl Drop for BlockAllocation {
|
||||||
fn drop(&mut self) {
|
fn drop(&mut self) {
|
||||||
self.block_allocator.free(self.blocks.clone())
|
let allocated_blocks = std::mem::take(&mut self.allocated_blocks);
|
||||||
|
self.block_allocator.free(allocated_blocks)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -82,85 +93,76 @@ impl BlockAllocator {
|
|||||||
/// For decode tokens, we allocate block by block
|
/// For decode tokens, we allocate block by block
|
||||||
///
|
///
|
||||||
/// If prompt tokens + min(decode_tokens, block_size) > window size, we repeat blocks and slots
|
/// If prompt tokens + min(decode_tokens, block_size) > window size, we repeat blocks and slots
|
||||||
fn allocate(
|
|
||||||
&self,
|
|
||||||
prompt_tokens: u32,
|
|
||||||
decode_tokens: u32,
|
|
||||||
) -> Result<(Vec<u32>, Vec<u32>), AllocationError> {
|
|
||||||
// let decode_tokens = min(decode_tokens, self.block_size);
|
|
||||||
// let tokens = prompt_tokens + decode_tokens;
|
|
||||||
|
|
||||||
let required_prompt_blocks = (prompt_tokens + self.block_size - 1) / self.block_size;
|
|
||||||
// prompt blocks + a single block for decode
|
|
||||||
let required_blocks = required_prompt_blocks + 1;
|
|
||||||
|
|
||||||
let (required_blocks, repeats) = match self.window_size {
|
|
||||||
// Nothing to do
|
|
||||||
None => (required_blocks, 1),
|
|
||||||
Some(window_size) => {
|
|
||||||
// Number of blocks needed for this window size
|
|
||||||
let window_size_required_blocks = (window_size + self.block_size - 1) / self.block_size;
|
|
||||||
// Number of times we will need to repeat blocks to cover the required allocation
|
|
||||||
let repeats = (required_blocks + window_size_required_blocks -1) / window_size_required_blocks;
|
|
||||||
let required_blocks = min(required_blocks, window_size_required_blocks);
|
|
||||||
|
|
||||||
(required_blocks, repeats)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
/// if prompt + decode < window size => do nothing
|
|
||||||
/// if prompt + decode > window size => do normal until we reach window size then
|
|
||||||
|
|
||||||
// Apply window size
|
|
||||||
let (required_blocks, repeats) = {
|
|
||||||
let (tokens, repeats) = match self.window_size {
|
|
||||||
None => (tokens, 1),
|
|
||||||
Some(window_size) => {
|
|
||||||
let repeats = (tokens + window_size - 1) / window_size;
|
|
||||||
let tokens = min(tokens, window_size);
|
|
||||||
(tokens, repeats as usize)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
// Pad to a multiple of block size
|
|
||||||
let required_blocks = (tokens + self.block_size - 1) / self.block_size;
|
|
||||||
(required_blocks, repeats)
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired");
|
|
||||||
|
|
||||||
if required_blocks > free_blocks.len() as u32 {
|
|
||||||
Err(AllocationError::NotEnoughPages)
|
|
||||||
} else {
|
|
||||||
let n_free_blocks = free_blocks.len();
|
|
||||||
let blocks =
|
|
||||||
free_blocks.split_off(n_free_blocks - required_blocks as usize);
|
|
||||||
let mut slots = Vec::with_capacity(
|
|
||||||
(required_blocks * self.block_size * repeats as u32) as usize,
|
|
||||||
);
|
|
||||||
|
|
||||||
for block_id in blocks.repeat(repeats).iter() {
|
|
||||||
for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
|
|
||||||
slots.push(s);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok((blocks, slots))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn block_allocation(
|
pub(crate) fn block_allocation(
|
||||||
&self,
|
&self,
|
||||||
prompt_tokens: u32,
|
prompt_tokens: u32,
|
||||||
decode_tokens: u32,
|
decode_tokens: u32,
|
||||||
) -> Result<BlockAllocation, AllocationError> {
|
) -> Result<BlockAllocation, AllocationError> {
|
||||||
self.allocate_inner(prompt_tokens, decode_tokens)
|
let required_prompt_blocks = (prompt_tokens + self.block_size - 1) / self.block_size;
|
||||||
.map(|(blocks, slots)| BlockAllocation {
|
// prompt blocks + a single block for decode
|
||||||
blocks,
|
let required_blocks = required_prompt_blocks + 1;
|
||||||
slots,
|
|
||||||
prompt_tokens,
|
let (clipped_required_blocks, repeats) = match self.window_size {
|
||||||
decode_tokens,
|
// Nothing to do
|
||||||
|
None => (required_blocks, 1),
|
||||||
|
Some(window_size) => {
|
||||||
|
// Number of blocks for this window size
|
||||||
|
let window_size_blocks = (window_size + self.block_size - 1) / self.block_size;
|
||||||
|
|
||||||
|
if required_blocks > window_size_blocks {
|
||||||
|
// Number of times we will need to repeat blocks to cover the required allocation
|
||||||
|
let repeats = (required_blocks + window_size_blocks - 1) / window_size_blocks;
|
||||||
|
(window_size_blocks, repeats)
|
||||||
|
} else {
|
||||||
|
(required_blocks, 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let repeats = repeats as usize;
|
||||||
|
let required_blocks = required_blocks as usize;
|
||||||
|
let clipped_required_blocks = clipped_required_blocks as usize;
|
||||||
|
|
||||||
|
let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired");
|
||||||
|
|
||||||
|
if clipped_required_blocks > free_blocks.len() {
|
||||||
|
Err(AllocationError::NotEnoughPages)
|
||||||
|
} else {
|
||||||
|
let n_free_blocks = free_blocks.len();
|
||||||
|
let allocated_blocks =
|
||||||
|
free_blocks.split_off(n_free_blocks - clipped_required_blocks);
|
||||||
|
|
||||||
|
let allocated_blocks = if repeats != 1 {
|
||||||
|
let mut allocated_blocks = allocated_blocks.repeat(repeats);
|
||||||
|
allocated_blocks.truncate(required_blocks);
|
||||||
|
allocated_blocks
|
||||||
|
} else {
|
||||||
|
allocated_blocks
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut allocated_slots = Vec::with_capacity(
|
||||||
|
allocated_blocks.len() * self.block_size as usize * repeats,
|
||||||
|
);
|
||||||
|
|
||||||
|
let required_slots = (prompt_tokens + decode_tokens) as usize;
|
||||||
|
|
||||||
|
'slots: for block_id in allocated_blocks.iter() {
|
||||||
|
for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
|
||||||
|
allocated_slots.push(s);
|
||||||
|
if allocated_slots.len() > required_slots {
|
||||||
|
break 'slots;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(BlockAllocation {
|
||||||
|
allocated_blocks,
|
||||||
|
allocated_slots,
|
||||||
|
required_blocks,
|
||||||
|
required_slots,
|
||||||
block_allocator: self.clone(),
|
block_allocator: self.clone(),
|
||||||
})
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn free(&self, blocks: Vec<u32>) {
|
pub(crate) fn free(&self, blocks: Vec<u32>) {
|
||||||
|
@ -283,7 +283,7 @@ impl State {
|
|||||||
let decode_tokens =
|
let decode_tokens =
|
||||||
entry.request.stopping_parameters.max_new_tokens + self.speculate;
|
entry.request.stopping_parameters.max_new_tokens + self.speculate;
|
||||||
match block_allocator
|
match block_allocator
|
||||||
.allocate(entry.request.input_length, decode_tokens)
|
.block_allocation(entry.request.input_length, decode_tokens)
|
||||||
{
|
{
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
// Entry is over budget
|
// Entry is over budget
|
||||||
@ -294,7 +294,7 @@ impl State {
|
|||||||
}
|
}
|
||||||
Ok(block_allocation) => {
|
Ok(block_allocation) => {
|
||||||
tracing::debug!("Allocation: {block_allocation:?}");
|
tracing::debug!("Allocation: {block_allocation:?}");
|
||||||
max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
|
max_blocks = max(max_blocks, block_allocation.blocks().len() as u32);
|
||||||
Some(block_allocation)
|
Some(block_allocation)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -313,8 +313,8 @@ impl State {
|
|||||||
let (blocks, slots) = match &block_allocation {
|
let (blocks, slots) = match &block_allocation {
|
||||||
None => (Vec::new(), Vec::new()),
|
None => (Vec::new(), Vec::new()),
|
||||||
Some(block_allocation) => (
|
Some(block_allocation) => (
|
||||||
block_allocation.blocks.clone(),
|
block_allocation.blocks().to_vec(),
|
||||||
block_allocation.slots.clone(),
|
block_allocation.slots().to_vec(),
|
||||||
),
|
),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -347,7 +347,7 @@ async fn filter_batch(
|
|||||||
let (blocks, slots) = entry
|
let (blocks, slots) = entry
|
||||||
.block_allocation
|
.block_allocation
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.map(|alloc| (alloc.blocks.clone(), alloc.slots.clone()))
|
.map(|alloc| (alloc.blocks().to_vec(), alloc.slots().to_vec()))
|
||||||
.unwrap_or((Vec::new(), Vec::new()));
|
.unwrap_or((Vec::new(), Vec::new()));
|
||||||
|
|
||||||
KeptRequest {
|
KeptRequest {
|
||||||
|
Loading…
Reference in New Issue
Block a user