Calculate token budget with padding to max_input_length (#2)

This commit is contained in:
Karol Damaszke 2023-12-11 09:24:27 +01:00 committed by GitHub
parent 6436ae86a1
commit b1897acfd6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 23 additions and 7 deletions

View File

@@ -49,11 +49,12 @@ impl Infer {
max_waiting_tokens: usize, max_waiting_tokens: usize,
max_concurrent_requests: usize, max_concurrent_requests: usize,
requires_padding: bool, requires_padding: bool,
max_input_length: u32,
window_size: Option<u32>, window_size: Option<u32>,
generation_health: Arc<AtomicBool>, generation_health: Arc<AtomicBool>,
) -> Self { ) -> Self {
// Infer shared state // Infer shared state
let queue = Queue::new(requires_padding, 16, window_size); let queue = Queue::new(requires_padding, max_input_length, 16, window_size);
let shared = Arc::new(Shared { let shared = Arc::new(Shared {
batching_task: Notify::new(), batching_task: Notify::new(),
}); });

View File

@@ -34,13 +34,19 @@ pub(crate) struct Queue {
} }
impl Queue { impl Queue {
pub(crate) fn new(requires_padding: bool, block_size: u32, window_size: Option<u32>) -> Self { pub(crate) fn new(
requires_padding: bool,
max_input_length: u32,
block_size: u32,
window_size: Option<u32>
) -> Self {
// Create channel // Create channel
let (queue_sender, queue_receiver) = mpsc::unbounded_channel(); let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
// Launch background queue task // Launch background queue task
tokio::spawn(queue_task( tokio::spawn(queue_task(
requires_padding, requires_padding,
max_input_length,
block_size, block_size,
window_size, window_size,
queue_receiver, queue_receiver,
@@ -89,11 +95,12 @@ impl Queue {
// Background task responsible of the queue state // Background task responsible of the queue state
async fn queue_task( async fn queue_task(
requires_padding: bool, requires_padding: bool,
max_input_length: u32,
block_size: u32, block_size: u32,
window_size: Option<u32>, window_size: Option<u32>,
mut receiver: mpsc::UnboundedReceiver<QueueCommand>, mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
) { ) {
let mut state = State::new(requires_padding, block_size, window_size); let mut state = State::new(requires_padding, max_input_length, block_size, window_size);
while let Some(cmd) = receiver.recv().await { while let Some(cmd) = receiver.recv().await {
match cmd { match cmd {
@@ -131,6 +138,9 @@ struct State {
/// Whether the model is using padding /// Whether the model is using padding
requires_padding: bool, requires_padding: bool,
/// Maximum input length, required for padding scenario
max_input_length: u32,
/// Paged Attention block size /// Paged Attention block size
block_size: u32, block_size: u32,
@@ -139,12 +149,18 @@ struct State {
} }
impl State { impl State {
fn new(requires_padding: bool, block_size: u32, window_size: Option<u32>) -> Self { fn new(
requires_padding: bool,
max_input_length: u32,
block_size: u32,
window_size: Option<u32>
) -> Self {
Self { Self {
entries: VecDeque::with_capacity(128), entries: VecDeque::with_capacity(128),
next_id: 0, next_id: 0,
next_batch_id: 0, next_batch_id: 0,
requires_padding, requires_padding,
max_input_length,
block_size, block_size,
window_size, window_size,
} }
@@ -187,7 +203,6 @@ impl State {
let mut batch_entries = let mut batch_entries =
IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default()); IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());
let mut max_input_length = 0;
let mut prefill_tokens: u32 = 0; let mut prefill_tokens: u32 = 0;
let mut decode_tokens: u32 = 0; let mut decode_tokens: u32 = 0;
@@ -203,8 +218,7 @@ impl State {
if self.requires_padding { if self.requires_padding {
// We pad to max input length in the Python shards // We pad to max input length in the Python shards
// We need to take these padding tokens into the equation // We need to take these padding tokens into the equation
max_input_length = max_input_length.max(entry.request.input_length); prefill_tokens = (batch_requests.len() + 1) as u32 * self.max_input_length
prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length
} else { } else {
// pad to block size // pad to block size
prefill_tokens += ((entry.request.input_length + self.block_size - 1) prefill_tokens += ((entry.request.input_length + self.block_size - 1)

View File

@@ -595,6 +595,7 @@ pub async fn run(
max_waiting_tokens, max_waiting_tokens,
max_concurrent_requests, max_concurrent_requests,
shard_info.requires_padding, shard_info.requires_padding,
max_input_length as u32,
shard_info.window_size, shard_info.window_size,
generation_health, generation_health,
); );