mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-23 16:02:10 +00:00
Calculate token budget with padding to max_input_length (#2)
parent 6436ae86a1
commit b1897acfd6
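This change threads a new max_input_length parameter from the server entrypoint through Infer::new, Queue::new, queue_task, and State. When the model requires padding, the prefill token budget is now computed as batch size times the configured max_input_length, rather than by tracking the largest input length seen so far while building the batch; since the Python shards pad every sequence to max input length (per the in-code comment), budgeting against the configured maximum keeps the router's estimate in line with what the shards actually allocate. A sketch of the budget arithmetic follows the diff below.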
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -49,11 +49,12 @@ impl Infer {
         max_waiting_tokens: usize,
         max_concurrent_requests: usize,
         requires_padding: bool,
+        max_input_length: u32,
         window_size: Option<u32>,
         generation_health: Arc<AtomicBool>,
     ) -> Self {
         // Infer shared state
-        let queue = Queue::new(requires_padding, 16, window_size);
+        let queue = Queue::new(requires_padding, max_input_length, 16, window_size);
         let shared = Arc::new(Shared {
             batching_task: Notify::new(),
         });
--- a/router/src/queue.rs
+++ b/router/src/queue.rs
@@ -34,13 +34,19 @@ pub(crate) struct Queue {
 }
 
 impl Queue {
-    pub(crate) fn new(requires_padding: bool, block_size: u32, window_size: Option<u32>) -> Self {
+    pub(crate) fn new(
+        requires_padding: bool,
+        max_input_length: u32,
+        block_size: u32,
+        window_size: Option<u32>
+    ) -> Self {
         // Create channel
         let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
 
         // Launch background queue task
         tokio::spawn(queue_task(
             requires_padding,
+            max_input_length,
             block_size,
             window_size,
             queue_receiver,
@@ -89,11 +95,12 @@ impl Queue {
 // Background task responsible of the queue state
 async fn queue_task(
     requires_padding: bool,
+    max_input_length: u32,
     block_size: u32,
     window_size: Option<u32>,
     mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
 ) {
-    let mut state = State::new(requires_padding, block_size, window_size);
+    let mut state = State::new(requires_padding, max_input_length, block_size, window_size);
 
     while let Some(cmd) = receiver.recv().await {
         match cmd {
@@ -131,6 +138,9 @@ struct State {
     /// Whether the model is using padding
     requires_padding: bool,
 
+    /// Maximum input length, required for padding scenario
+    max_input_length: u32,
+
     /// Paged Attention block size
     block_size: u32,
 
@@ -139,12 +149,18 @@ struct State {
 }
 
 impl State {
-    fn new(requires_padding: bool, block_size: u32, window_size: Option<u32>) -> Self {
+    fn new(
+        requires_padding: bool,
+        max_input_length: u32,
+        block_size: u32,
+        window_size: Option<u32>
+    ) -> Self {
         Self {
             entries: VecDeque::with_capacity(128),
             next_id: 0,
             next_batch_id: 0,
             requires_padding,
+            max_input_length,
             block_size,
             window_size,
         }
@@ -187,7 +203,6 @@ impl State {
         let mut batch_entries =
             IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());
 
-        let mut max_input_length = 0;
         let mut prefill_tokens: u32 = 0;
         let mut decode_tokens: u32 = 0;
 
@@ -203,8 +218,7 @@ impl State {
             if self.requires_padding {
                 // We pad to max input length in the Python shards
                 // We need to take these padding tokens into the equation
-                max_input_length = max_input_length.max(entry.request.input_length);
-                prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length
+                prefill_tokens = (batch_requests.len() + 1) as u32 * self.max_input_length
             } else {
                 // pad to block size
                 prefill_tokens += ((entry.request.input_length + self.block_size - 1)
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -595,6 +595,7 @@ pub async fn run(
         max_waiting_tokens,
         max_concurrent_requests,
         shard_info.requires_padding,
+        max_input_length as u32,
         shard_info.window_size,
         generation_health,
     );
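To make the budget arithmetic concrete, here is a minimal, self-contained Rust sketch of the two prefill budgets the batching logic distinguishes. The helper functions and their names are hypothetical (they are not the repo's API); the padding formula mirrors the branch changed by this commit, and the block-rounding formula mirrors the unchanged paged-attention branch.

// Hypothetical standalone helpers, assuming only the formulas from the diff above.

/// Padding budget: with requires_padding, every sequence is padded to the
/// configured max_input_length, so the reserved prefill budget is simply
/// batch_size * max_input_length, regardless of the actual input lengths.
fn padded_prefill_budget(batch_size: u32, max_input_length: u32) -> u32 {
    batch_size * max_input_length
}

/// Paged-attention budget: without padding, each input is instead rounded
/// up to a whole number of KV-cache blocks, using the same rounding as the
/// diff: (len + block_size - 1) / block_size blocks of block_size tokens.
fn paged_prefill_budget(input_lengths: &[u32], block_size: u32) -> u32 {
    input_lengths
        .iter()
        .map(|len| ((len + block_size - 1) / block_size) * block_size)
        .sum()
}

fn main() {
    // Three queued requests with 10, 50 and 300 input tokens.
    let lengths = [10u32, 50, 300];

    // Padding branch, max_input_length = 512: 3 * 512 = 1536 tokens reserved.
    assert_eq!(padded_prefill_budget(lengths.len() as u32, 512), 1536);

    // Block branch, block_size = 16: 16 + 64 + 304 = 384 tokens reserved.
    assert_eq!(paged_prefill_budget(&lengths, 16), 384);
}

The padded budget is deliberately pessimistic: 1536 tokens are reserved even though only 360 are real input. Presumably the shards materialize padding up to the configured maximum, so the earlier running-maximum estimate could under-count a padded batch's true memory footprint; budgeting against max_input_length trades some throughput headroom for a bound that cannot be exceeded.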