mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00
Port over block allocator interface (with token ids)
This commit is contained in:
parent
4562c16048
commit
5d482d4da2
@ -1,16 +1,18 @@
|
||||
use std::cmp::min;
|
||||
use std::{cmp::min, sync::Arc};
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct BlockAllocation {
|
||||
pub blocks: Vec<u32>,
|
||||
pub slots: Vec<u32>,
|
||||
pub allocation_id: u64,
|
||||
block_allocator: BlockAllocator,
|
||||
}
|
||||
|
||||
impl Drop for BlockAllocation {
|
||||
fn drop(&mut self) {
|
||||
self.block_allocator.free(self.blocks.clone())
|
||||
self.block_allocator
|
||||
.free(self.blocks.clone(), self.allocation_id)
|
||||
}
|
||||
}
|
||||
|
||||
@ -42,11 +44,16 @@ impl BlockAllocator {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn allocate(&self, tokens: u32) -> Option<BlockAllocation> {
|
||||
pub(crate) async fn allocate(
|
||||
&self,
|
||||
tokens: u32,
|
||||
prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||
) -> Option<BlockAllocation> {
|
||||
let (response_sender, response_receiver) = oneshot::channel();
|
||||
self.block_allocator
|
||||
.send(BlockAllocatorCommand::Allocate {
|
||||
tokens,
|
||||
prefill_tokens,
|
||||
response_sender,
|
||||
})
|
||||
.unwrap();
|
||||
@ -54,16 +61,20 @@ impl BlockAllocator {
|
||||
response_receiver
|
||||
.await
|
||||
.unwrap()
|
||||
.map(|(blocks, slots)| BlockAllocation {
|
||||
.map(|(blocks, slots, allocation_id)| BlockAllocation {
|
||||
blocks,
|
||||
slots,
|
||||
allocation_id,
|
||||
block_allocator: self.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn free(&self, blocks: Vec<u32>) {
|
||||
pub(crate) fn free(&self, blocks: Vec<u32>, allocation_id: u64) {
|
||||
self.block_allocator
|
||||
.send(BlockAllocatorCommand::Free { blocks })
|
||||
.send(BlockAllocatorCommand::Free {
|
||||
allocation_id,
|
||||
blocks,
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
@ -74,63 +85,111 @@ async fn block_allocator_task(
|
||||
window_size: Option<u32>,
|
||||
mut receiver: mpsc::UnboundedReceiver<BlockAllocatorCommand>,
|
||||
) {
|
||||
// Block 0 is reserved for health checks
|
||||
let mut free_blocks: Vec<u32> = (1..blocks).collect();
|
||||
let mut allocator = SimpleAllocator::new(blocks, block_size, window_size);
|
||||
while let Some(cmd) = receiver.recv().await {
|
||||
match cmd {
|
||||
BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks),
|
||||
BlockAllocatorCommand::Free {
|
||||
blocks,
|
||||
allocation_id,
|
||||
} => allocator.free(blocks, allocation_id),
|
||||
BlockAllocatorCommand::Allocate {
|
||||
tokens,
|
||||
prefill_tokens,
|
||||
response_sender,
|
||||
} => {
|
||||
// Apply window size
|
||||
let (required_blocks, repeats) = {
|
||||
let (tokens, repeats) = match window_size {
|
||||
None => (tokens, 1),
|
||||
Some(window_size) => {
|
||||
let repeats = (tokens + window_size - 1) / window_size;
|
||||
let tokens = min(tokens, window_size);
|
||||
(tokens, repeats as usize)
|
||||
}
|
||||
};
|
||||
// Pad to a multiple of block size
|
||||
let required_blocks = (tokens + block_size - 1) / block_size;
|
||||
(required_blocks, repeats)
|
||||
};
|
||||
|
||||
let tokens = tokens as usize;
|
||||
let allocation = if required_blocks > free_blocks.len() as u32 {
|
||||
None
|
||||
} else {
|
||||
let blocks =
|
||||
free_blocks.split_off(free_blocks.len() - required_blocks as usize);
|
||||
let mut slots = Vec::with_capacity(
|
||||
(required_blocks * block_size * repeats as u32) as usize,
|
||||
);
|
||||
|
||||
'slots: for block_id in blocks.repeat(repeats).iter() {
|
||||
for s in (block_id * block_size)..((block_id + 1) * block_size) {
|
||||
slots.push(s);
|
||||
if slots.len() == tokens {
|
||||
break 'slots;
|
||||
}
|
||||
}
|
||||
}
|
||||
Some((blocks, slots))
|
||||
};
|
||||
response_sender.send(allocation).unwrap();
|
||||
let prefill_tokens_slice = prefill_tokens.as_ref().map(|p| p.as_slice());
|
||||
response_sender
|
||||
.send(allocator.allocate(tokens, prefill_tokens_slice))
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Allocator {
|
||||
fn allocate(
|
||||
&mut self,
|
||||
tokens: u32,
|
||||
prefill_tokens: Option<&[u32]>,
|
||||
) -> Option<(Vec<u32>, Vec<u32>, u64)>;
|
||||
|
||||
fn free(&mut self, blocks: Vec<u32>, allocation_id: u64);
|
||||
}
|
||||
|
||||
pub struct SimpleAllocator {
|
||||
free_blocks: Vec<u32>,
|
||||
block_size: u32,
|
||||
window_size: Option<u32>,
|
||||
}
|
||||
|
||||
impl SimpleAllocator {
|
||||
fn new(blocks: u32, block_size: u32, window_size: Option<u32>) -> Self {
|
||||
SimpleAllocator {
|
||||
block_size,
|
||||
// Block 0 is reserved for health checks
|
||||
free_blocks: (1..blocks).collect(),
|
||||
window_size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Allocator for SimpleAllocator {
|
||||
fn allocate(
|
||||
&mut self,
|
||||
tokens: u32,
|
||||
_prefill_tokens: Option<&[u32]>,
|
||||
) -> Option<(Vec<u32>, Vec<u32>, u64)> {
|
||||
// Apply window size
|
||||
let (required_blocks, repeats) = {
|
||||
let (tokens, repeats) = match self.window_size {
|
||||
None => (tokens, 1),
|
||||
Some(window_size) => {
|
||||
let repeats = (tokens + window_size - 1) / window_size;
|
||||
let tokens = min(tokens, window_size);
|
||||
(tokens, repeats as usize)
|
||||
}
|
||||
};
|
||||
// Pad to a multiple of block size
|
||||
let required_blocks = (tokens + self.block_size - 1) / self.block_size;
|
||||
(required_blocks, repeats)
|
||||
};
|
||||
|
||||
let tokens = tokens as usize;
|
||||
if required_blocks > self.free_blocks.len() as u32 {
|
||||
None
|
||||
} else {
|
||||
let blocks = self
|
||||
.free_blocks
|
||||
.split_off(self.free_blocks.len() - required_blocks as usize);
|
||||
let mut slots =
|
||||
Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize);
|
||||
|
||||
'slots: for block_id in blocks.repeat(repeats).iter() {
|
||||
for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
|
||||
slots.push(s);
|
||||
if slots.len() == tokens {
|
||||
break 'slots;
|
||||
}
|
||||
}
|
||||
}
|
||||
Some((blocks, slots, 0))
|
||||
}
|
||||
}
|
||||
|
||||
fn free(&mut self, blocks: Vec<u32>, _allocation_id: u64) {
|
||||
self.free_blocks.extend(blocks)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum BlockAllocatorCommand {
|
||||
Free {
|
||||
blocks: Vec<u32>,
|
||||
allocation_id: u64,
|
||||
},
|
||||
Allocate {
|
||||
tokens: u32,
|
||||
response_sender: oneshot::Sender<Option<(Vec<u32>, Vec<u32>)>>,
|
||||
prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||
response_sender: oneshot::Sender<Option<(Vec<u32>, Vec<u32>, u64)>>,
|
||||
},
|
||||
}
|
||||
|
@ -298,7 +298,10 @@ impl State {
|
||||
+ self.speculate
|
||||
- 1;
|
||||
|
||||
match block_allocator.allocate(tokens).await {
|
||||
match block_allocator
|
||||
.allocate(tokens, entry.request.input_ids.clone())
|
||||
.await
|
||||
{
|
||||
None => {
|
||||
// Entry is over budget
|
||||
// Add it back to the front
|
||||
|
@ -11,6 +11,7 @@ use rand::{thread_rng, Rng};
|
||||
use serde_json::Value;
|
||||
use std::io::Cursor;
|
||||
use std::iter;
|
||||
use std::sync::Arc;
|
||||
use thiserror::Error;
|
||||
use tokenizers::tokenizer::Tokenizer;
|
||||
use tokio::sync::mpsc;
|
||||
@ -121,7 +122,7 @@ impl Validation {
|
||||
inputs: String,
|
||||
truncate: Option<usize>,
|
||||
max_new_tokens: Option<u32>,
|
||||
) -> Result<(Vec<Chunk>, usize, u32), ValidationError> {
|
||||
) -> Result<(Vec<Chunk>, Option<Vec<u32>>, usize, u32), ValidationError> {
|
||||
// If we have a fast tokenizer
|
||||
if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
|
||||
// Create response channel
|
||||
@ -156,8 +157,10 @@ impl Validation {
|
||||
));
|
||||
}
|
||||
|
||||
let input_ids = encoding.get_ids()[..input_length].to_owned();
|
||||
|
||||
metrics::histogram!("tgi_request_input_length").record(input_length as f64);
|
||||
Ok((inputs, input_length, max_new_tokens))
|
||||
Ok((inputs, Some(input_ids), input_length, max_new_tokens))
|
||||
}
|
||||
// Return inputs without validation
|
||||
else {
|
||||
@ -180,7 +183,12 @@ impl Validation {
|
||||
input_length = input_length.saturating_sub(max_new_tokens as usize);
|
||||
}
|
||||
|
||||
Ok((vec![Chunk::Text(inputs)], input_length, max_new_tokens))
|
||||
Ok((
|
||||
vec![Chunk::Text(inputs)],
|
||||
None,
|
||||
input_length,
|
||||
max_new_tokens,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
@ -314,7 +322,7 @@ impl Validation {
|
||||
.unwrap_or(Ok(None))?;
|
||||
|
||||
// Validate inputs
|
||||
let (inputs, input_length, max_new_tokens) = self
|
||||
let (inputs, input_ids, input_length, max_new_tokens) = self
|
||||
.validate_input(request.inputs, truncate, max_new_tokens)
|
||||
.await?;
|
||||
|
||||
@ -391,6 +399,7 @@ impl Validation {
|
||||
|
||||
Ok(ValidGenerateRequest {
|
||||
inputs,
|
||||
input_ids: input_ids.map(Arc::new),
|
||||
decoder_input_details,
|
||||
input_length: input_length as u32,
|
||||
truncate: truncate.unwrap_or(self.max_input_length) as u32,
|
||||
@ -707,6 +716,7 @@ pub struct ValidStoppingParameters {
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ValidGenerateRequest {
|
||||
pub inputs: Vec<Chunk>,
|
||||
pub input_ids: Option<Arc<Vec<u32>>>,
|
||||
pub input_length: u32,
|
||||
pub truncate: u32,
|
||||
pub decoder_input_details: bool,
|
||||
|
Loading…
Reference in New Issue
Block a user