add block size parameter

This commit is contained in:
OlivierDehaene 2023-07-18 12:45:51 +02:00
parent d2e3843588
commit 79616a8796
2 changed files with 29 additions and 17 deletions

View File

@ -53,7 +53,7 @@ impl Infer {
generation_health: Arc<AtomicBool>, generation_health: Arc<AtomicBool>,
) -> Self { ) -> Self {
// Infer shared state // Infer shared state
let queue = Queue::new(requires_padding); let queue = Queue::new(requires_padding, 16);
let shared = Arc::new(Shared { let shared = Arc::new(Shared {
batching_task: Notify::new(), batching_task: Notify::new(),
}); });

View File

@ -33,12 +33,12 @@ pub(crate) struct Queue {
} }
impl Queue { impl Queue {
pub(crate) fn new(requires_padding: bool) -> Self { pub(crate) fn new(requires_padding: bool, block_size: u32) -> Self {
// Create channel // Create channel
let (queue_sender, queue_receiver) = flume::unbounded(); let (queue_sender, queue_receiver) = flume::unbounded();
// Launch background queue task // Launch background queue task
tokio::spawn(queue_task(requires_padding, queue_receiver)); tokio::spawn(queue_task(requires_padding, block_size, queue_receiver));
Self { queue_sender } Self { queue_sender }
} }
@ -81,8 +81,12 @@ impl Queue {
} }
// Background task responsible of the queue state // Background task responsible of the queue state
async fn queue_task(requires_padding: bool, receiver: flume::Receiver<QueueCommand>) { async fn queue_task(
let mut state = State::new(requires_padding); requires_padding: bool,
block_size: u32,
receiver: flume::Receiver<QueueCommand>,
) {
let mut state = State::new(requires_padding, block_size);
while let Ok(cmd) = receiver.recv_async().await { while let Ok(cmd) = receiver.recv_async().await {
match cmd { match cmd {
@ -119,15 +123,19 @@ struct State {
/// Whether the model is using padding /// Whether the model is using padding
requires_padding: bool, requires_padding: bool,
/// Paged Attention block size
block_size: u32,
} }
impl State { impl State {
fn new(requires_padding: bool) -> Self { fn new(requires_padding: bool, block_size: u32) -> Self {
Self { Self {
entries: VecDeque::with_capacity(128), entries: VecDeque::with_capacity(128),
next_id: 0, next_id: 0,
next_batch_id: 0, next_batch_id: 0,
requires_padding, requires_padding,
block_size,
} }
} }
@ -188,7 +196,9 @@ impl State {
prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length
} else { } else {
// pad to block size // pad to block size
prefill_tokens += ((entry.request.input_length + 16 - 1) / 16) * 16; prefill_tokens += ((entry.request.input_length + self.block_size - 1)
/ self.block_size)
* self.block_size;
} }
if self.requires_padding { if self.requires_padding {
@ -196,7 +206,9 @@ impl State {
} else { } else {
// pad to block size // pad to block size
decode_tokens += decode_tokens +=
((entry.request.stopping_parameters.max_new_tokens + 16 - 1) / 16) * 16; ((entry.request.stopping_parameters.max_new_tokens + self.block_size - 1)
/ self.block_size)
* self.block_size;
} }
if prefill_tokens > prefill_token_budget if prefill_tokens > prefill_token_budget
@ -328,7 +340,7 @@ mod tests {
#[test] #[test]
fn test_append() { fn test_append() {
let mut state = State::new(false); let mut state = State::new(false, 1);
let (entry, _guard) = default_entry(); let (entry, _guard) = default_entry();
assert_eq!(state.next_id, 0); assert_eq!(state.next_id, 0);
@ -344,7 +356,7 @@ mod tests {
#[test] #[test]
fn test_next_batch_empty() { fn test_next_batch_empty() {
let mut state = State::new(false); let mut state = State::new(false, 1);
assert!(state.next_batch(None, 1, 1).is_none()); assert!(state.next_batch(None, 1, 1).is_none());
assert!(state.next_batch(Some(1), 1, 1).is_none()); assert!(state.next_batch(Some(1), 1, 1).is_none());
@ -352,7 +364,7 @@ mod tests {
#[test] #[test]
fn test_next_batch_min_size() { fn test_next_batch_min_size() {
let mut state = State::new(false); let mut state = State::new(false, 1);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
state.append(entry1); state.append(entry1);
@ -384,7 +396,7 @@ mod tests {
#[test] #[test]
fn test_next_batch_token_budget() { fn test_next_batch_token_budget() {
let mut state = State::new(false); let mut state = State::new(false, 1);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
state.append(entry1); state.append(entry1);
@ -417,14 +429,14 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_append() { async fn test_queue_append() {
let queue = Queue::new(false); let queue = Queue::new(false, 1);
let (entry, _guard) = default_entry(); let (entry, _guard) = default_entry();
queue.append(entry); queue.append(entry);
} }
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_empty() { async fn test_queue_next_batch_empty() {
let queue = Queue::new(false); let queue = Queue::new(false, 1);
assert!(queue.next_batch(None, 1, 1).await.is_none()); assert!(queue.next_batch(None, 1, 1).await.is_none());
assert!(queue.next_batch(Some(1), 1, 1).await.is_none()); assert!(queue.next_batch(Some(1), 1, 1).await.is_none());
@ -432,7 +444,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_min_size() { async fn test_queue_next_batch_min_size() {
let queue = Queue::new(false); let queue = Queue::new(false, 1);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
queue.append(entry1); queue.append(entry1);
@ -465,7 +477,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_token_budget() { async fn test_queue_next_batch_token_budget() {
let queue = Queue::new(false); let queue = Queue::new(false, 1);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
queue.append(entry1); queue.append(entry1);
@ -490,7 +502,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_dropped_receiver() { async fn test_queue_next_batch_dropped_receiver() {
let queue = Queue::new(false); let queue = Queue::new(false, 1);
let (entry, _) = default_entry(); let (entry, _) = default_entry();
queue.append(entry); queue.append(entry);