mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-15 05:42:07 +00:00
fix windowing
This commit is contained in:
parent
37266e2dbb
commit
c2fb459bc1
@ -24,6 +24,9 @@ impl BlockAllocation {
|
|||||||
&self.allocated_slots
|
&self.allocated_slots
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Extend an allocation by adding a new block
|
||||||
|
/// If the allocation length > window size, repeats blocks and slots to cover the
|
||||||
|
/// whole `required_blocks` and `required_slots`
|
||||||
pub(crate) fn extend(&mut self) -> Result<(), AllocationError> {
|
pub(crate) fn extend(&mut self) -> Result<(), AllocationError> {
|
||||||
let (block, slots) = self.block_allocator.allocate_block()?;
|
let (block, slots) = self.block_allocator.allocate_block()?;
|
||||||
// Add block and slots to current allocation
|
// Add block and slots to current allocation
|
||||||
@ -48,6 +51,7 @@ impl BlockAllocation {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Drop for BlockAllocation {
|
impl Drop for BlockAllocation {
|
||||||
|
/// Free the blocks
|
||||||
fn drop(&mut self) {
|
fn drop(&mut self) {
|
||||||
let allocated_blocks = std::mem::take(&mut self.allocated_blocks);
|
let allocated_blocks = std::mem::take(&mut self.allocated_blocks);
|
||||||
self.block_allocator.free(allocated_blocks)
|
self.block_allocator.free(allocated_blocks)
|
||||||
@ -114,66 +118,71 @@ impl BlockAllocator {
|
|||||||
let required_prompt_blocks = (prompt_tokens + self.block_size - 1) / self.block_size;
|
let required_prompt_blocks = (prompt_tokens + self.block_size - 1) / self.block_size;
|
||||||
// prompt blocks + a single block for decode
|
// prompt blocks + a single block for decode
|
||||||
let required_blocks = required_prompt_blocks + 1;
|
let required_blocks = required_prompt_blocks + 1;
|
||||||
|
let required_slots = required_blocks * self.block_size;
|
||||||
|
|
||||||
|
// Slots and blocks required for the whole request
|
||||||
|
let total_slots = prompt_tokens + decode_tokens;
|
||||||
|
let total_required_blocks = (total_slots + self.block_size - 1) / self.block_size;
|
||||||
|
|
||||||
let (clipped_required_blocks, repeats) = match self.window_size {
|
let (clipped_required_blocks, repeats) = match self.window_size {
|
||||||
// Nothing to do
|
Some(window_size) if required_slots >= window_size => {
|
||||||
None => (required_blocks, 1),
|
|
||||||
Some(window_size) => {
|
|
||||||
// Number of blocks for this window size
|
// Number of blocks for this window size
|
||||||
let window_size_blocks = (window_size + self.block_size - 1) / self.block_size;
|
let window_size_blocks = (window_size + self.block_size - 1) / self.block_size;
|
||||||
|
// Number of times we will need to repeat blocks to cover the total allocation
|
||||||
if required_blocks > window_size_blocks {
|
let repeats = (total_slots + window_size - 1) / window_size;
|
||||||
// Number of times we will need to repeat blocks to cover the required allocation
|
(window_size_blocks, repeats)
|
||||||
let repeats = (required_blocks + window_size_blocks - 1) / window_size_blocks;
|
|
||||||
(window_size_blocks, repeats)
|
|
||||||
} else {
|
|
||||||
(required_blocks, 1)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
// Nothing to do
|
||||||
|
_ => (required_blocks, 1),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Scoped to drop the lock early
|
||||||
|
let allocated_blocks = {
|
||||||
|
let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired");
|
||||||
|
let clipped_required_blocks = clipped_required_blocks as usize;
|
||||||
|
|
||||||
|
if clipped_required_blocks > free_blocks.len() {
|
||||||
|
// Not enough blocks to cover this allocation
|
||||||
|
// Early return
|
||||||
|
return Err(AllocationError::NotEnoughPages);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Take the blocks
|
||||||
|
let n_free_blocks = free_blocks.len();
|
||||||
|
free_blocks.split_off(n_free_blocks - clipped_required_blocks)
|
||||||
};
|
};
|
||||||
|
|
||||||
let repeats = repeats as usize;
|
let repeats = repeats as usize;
|
||||||
let required_blocks = required_blocks as usize;
|
let total_slots = total_slots as usize;
|
||||||
let clipped_required_blocks = clipped_required_blocks as usize;
|
let total_required_blocks = total_required_blocks as usize;
|
||||||
|
|
||||||
let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired");
|
let allocated_blocks = if repeats != 1 {
|
||||||
|
let mut allocated_blocks = allocated_blocks.repeat(repeats);
|
||||||
if clipped_required_blocks > free_blocks.len() {
|
allocated_blocks.truncate(total_required_blocks);
|
||||||
Err(AllocationError::NotEnoughPages)
|
allocated_blocks
|
||||||
} else {
|
} else {
|
||||||
let n_free_blocks = free_blocks.len();
|
allocated_blocks
|
||||||
let allocated_blocks = free_blocks.split_off(n_free_blocks - clipped_required_blocks);
|
};
|
||||||
|
|
||||||
let allocated_blocks = if repeats != 1 {
|
let mut allocated_slots =
|
||||||
let mut allocated_blocks = allocated_blocks.repeat(repeats);
|
Vec::with_capacity(allocated_blocks.len() * self.block_size as usize * repeats);
|
||||||
allocated_blocks.truncate(required_blocks);
|
|
||||||
allocated_blocks
|
|
||||||
} else {
|
|
||||||
allocated_blocks
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut allocated_slots =
|
'slots: for block_id in allocated_blocks.iter() {
|
||||||
Vec::with_capacity(allocated_blocks.len() * self.block_size as usize * repeats);
|
for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
|
||||||
|
allocated_slots.push(s);
|
||||||
let required_slots = (prompt_tokens + decode_tokens) as usize;
|
if allocated_slots.len() > total_slots {
|
||||||
|
break 'slots;
|
||||||
'slots: for block_id in allocated_blocks.iter() {
|
|
||||||
for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
|
|
||||||
allocated_slots.push(s);
|
|
||||||
if allocated_slots.len() > required_slots {
|
|
||||||
break 'slots;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(BlockAllocation {
|
|
||||||
allocated_blocks,
|
|
||||||
allocated_slots,
|
|
||||||
required_blocks,
|
|
||||||
required_slots,
|
|
||||||
block_allocator: self.clone(),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Ok(BlockAllocation {
|
||||||
|
allocated_blocks,
|
||||||
|
allocated_slots,
|
||||||
|
required_blocks: total_required_blocks,
|
||||||
|
required_slots: total_slots,
|
||||||
|
block_allocator: self.clone(),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn free(&self, blocks: Vec<u32>) {
|
pub(crate) fn free(&self, blocks: Vec<u32>) {
|
||||||
|
@ -361,6 +361,7 @@ async fn decode(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Filter a `batch` and remove all requests not present in `entries`
|
/// Filter a `batch` and remove all requests not present in `entries`
|
||||||
|
/// Ask the server to generate the full texts for entries in `terminated_entries`
|
||||||
#[instrument(skip_all)]
|
#[instrument(skip_all)]
|
||||||
async fn filter_batch(
|
async fn filter_batch(
|
||||||
client: &mut ShardedClient,
|
client: &mut ShardedClient,
|
||||||
@ -408,7 +409,10 @@ async fn filter_batch(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
///
|
/// Send `InferStreamResponse::Intermediate` and the final `InferStreamResponse::End` messages
|
||||||
|
/// to terminated requests
|
||||||
|
/// It modifies the last `InferStreamResponse::Intermediate` to add the final full text in
|
||||||
|
/// `terminated_generations`
|
||||||
#[instrument(skip_all)]
|
#[instrument(skip_all)]
|
||||||
fn send_terminated_generations(
|
fn send_terminated_generations(
|
||||||
terminated_generations: Vec<TerminatedGeneration>,
|
terminated_generations: Vec<TerminatedGeneration>,
|
||||||
@ -530,7 +534,7 @@ fn send_stream_responses(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Check if block allocations need to be extended
|
/// Check if block allocations need to be extended
|
||||||
/// If we don't have enough blocks, request will be filtered with be added to an IntMap of
|
/// If we don't have enough blocks, request will be filtered and added to an IntMap of
|
||||||
/// terminated entries.
|
/// terminated entries.
|
||||||
/// If at least one entry allocation was extended, we return true to force an update
|
/// If at least one entry allocation was extended, we return true to force an update
|
||||||
#[instrument(skip_all)]
|
#[instrument(skip_all)]
|
||||||
@ -592,6 +596,7 @@ fn filter_send_update_allocations(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Map `Generation` to `<(bool, Vec<(u64, InferStreamResponse)>)>`
|
/// Map `Generation` to `<(bool, Vec<(u64, InferStreamResponse)>)>`
|
||||||
|
/// `bool` is `true` if the generation is finished
|
||||||
fn map_generation(generation: Generation, entry: &Entry) -> (bool, Vec<InferStreamResponse>) {
|
fn map_generation(generation: Generation, entry: &Entry) -> (bool, Vec<InferStreamResponse>) {
|
||||||
let mut finished = false;
|
let mut finished = false;
|
||||||
let mut stream_responses = Vec::with_capacity(16);
|
let mut stream_responses = Vec::with_capacity(16);
|
||||||
|
Loading…
Reference in New Issue
Block a user