mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00
add block size parameter
This commit is contained in:
parent
d2e3843588
commit
79616a8796
@ -53,7 +53,7 @@ impl Infer {
|
||||
generation_health: Arc<AtomicBool>,
|
||||
) -> Self {
|
||||
// Infer shared state
|
||||
let queue = Queue::new(requires_padding);
|
||||
let queue = Queue::new(requires_padding, 16);
|
||||
let shared = Arc::new(Shared {
|
||||
batching_task: Notify::new(),
|
||||
});
|
||||
|
@ -33,12 +33,12 @@ pub(crate) struct Queue {
|
||||
}
|
||||
|
||||
impl Queue {
|
||||
pub(crate) fn new(requires_padding: bool) -> Self {
|
||||
pub(crate) fn new(requires_padding: bool, block_size: u32) -> Self {
|
||||
// Create channel
|
||||
let (queue_sender, queue_receiver) = flume::unbounded();
|
||||
|
||||
// Launch background queue task
|
||||
tokio::spawn(queue_task(requires_padding, queue_receiver));
|
||||
tokio::spawn(queue_task(requires_padding, block_size, queue_receiver));
|
||||
|
||||
Self { queue_sender }
|
||||
}
|
||||
@ -81,8 +81,12 @@ impl Queue {
|
||||
}
|
||||
|
||||
// Background task responsible of the queue state
|
||||
async fn queue_task(requires_padding: bool, receiver: flume::Receiver<QueueCommand>) {
|
||||
let mut state = State::new(requires_padding);
|
||||
async fn queue_task(
|
||||
requires_padding: bool,
|
||||
block_size: u32,
|
||||
receiver: flume::Receiver<QueueCommand>,
|
||||
) {
|
||||
let mut state = State::new(requires_padding, block_size);
|
||||
|
||||
while let Ok(cmd) = receiver.recv_async().await {
|
||||
match cmd {
|
||||
@ -119,15 +123,19 @@ struct State {
|
||||
|
||||
/// Whether the model is using padding
|
||||
requires_padding: bool,
|
||||
|
||||
/// Paged Attention block size
|
||||
block_size: u32,
|
||||
}
|
||||
|
||||
impl State {
|
||||
fn new(requires_padding: bool) -> Self {
|
||||
fn new(requires_padding: bool, block_size: u32) -> Self {
|
||||
Self {
|
||||
entries: VecDeque::with_capacity(128),
|
||||
next_id: 0,
|
||||
next_batch_id: 0,
|
||||
requires_padding,
|
||||
block_size,
|
||||
}
|
||||
}
|
||||
|
||||
@ -188,7 +196,9 @@ impl State {
|
||||
prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length
|
||||
} else {
|
||||
// pad to block size
|
||||
prefill_tokens += ((entry.request.input_length + 16 - 1) / 16) * 16;
|
||||
prefill_tokens += ((entry.request.input_length + self.block_size - 1)
|
||||
/ self.block_size)
|
||||
* self.block_size;
|
||||
}
|
||||
|
||||
if self.requires_padding {
|
||||
@ -196,7 +206,9 @@ impl State {
|
||||
} else {
|
||||
// pad to block size
|
||||
decode_tokens +=
|
||||
((entry.request.stopping_parameters.max_new_tokens + 16 - 1) / 16) * 16;
|
||||
((entry.request.stopping_parameters.max_new_tokens + self.block_size - 1)
|
||||
/ self.block_size)
|
||||
* self.block_size;
|
||||
}
|
||||
|
||||
if prefill_tokens > prefill_token_budget
|
||||
@ -328,7 +340,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_append() {
|
||||
let mut state = State::new(false);
|
||||
let mut state = State::new(false, 1);
|
||||
let (entry, _guard) = default_entry();
|
||||
|
||||
assert_eq!(state.next_id, 0);
|
||||
@ -344,7 +356,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_next_batch_empty() {
|
||||
let mut state = State::new(false);
|
||||
let mut state = State::new(false, 1);
|
||||
|
||||
assert!(state.next_batch(None, 1, 1).is_none());
|
||||
assert!(state.next_batch(Some(1), 1, 1).is_none());
|
||||
@ -352,7 +364,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_next_batch_min_size() {
|
||||
let mut state = State::new(false);
|
||||
let mut state = State::new(false, 1);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
state.append(entry1);
|
||||
@ -384,7 +396,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_next_batch_token_budget() {
|
||||
let mut state = State::new(false);
|
||||
let mut state = State::new(false, 1);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
state.append(entry1);
|
||||
@ -417,14 +429,14 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_append() {
|
||||
let queue = Queue::new(false);
|
||||
let queue = Queue::new(false, 1);
|
||||
let (entry, _guard) = default_entry();
|
||||
queue.append(entry);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_empty() {
|
||||
let queue = Queue::new(false);
|
||||
let queue = Queue::new(false, 1);
|
||||
|
||||
assert!(queue.next_batch(None, 1, 1).await.is_none());
|
||||
assert!(queue.next_batch(Some(1), 1, 1).await.is_none());
|
||||
@ -432,7 +444,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_min_size() {
|
||||
let queue = Queue::new(false);
|
||||
let queue = Queue::new(false, 1);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
@ -465,7 +477,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_token_budget() {
|
||||
let queue = Queue::new(false);
|
||||
let queue = Queue::new(false, 1);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
@ -490,7 +502,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_dropped_receiver() {
|
||||
let queue = Queue::new(false);
|
||||
let queue = Queue::new(false, 1);
|
||||
let (entry, _) = default_entry();
|
||||
queue.append(entry);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user