Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-04-23 16:02:10 +00:00
Calculate token budget with padding to max_input_length (#2)
This commit is contained in:
parent 6436ae86a1
commit b1897acfd6
@@ -49,11 +49,12 @@ impl Infer {
         max_waiting_tokens: usize,
         max_concurrent_requests: usize,
         requires_padding: bool,
+        max_input_length: u32,
         window_size: Option<u32>,
         generation_health: Arc<AtomicBool>,
     ) -> Self {
         // Infer shared state
-        let queue = Queue::new(requires_padding, 16, window_size);
+        let queue = Queue::new(requires_padding, max_input_length, 16, window_size);
         let shared = Arc::new(Shared {
             batching_task: Notify::new(),
         });
@@ -34,13 +34,19 @@ pub(crate) struct Queue {
 }

 impl Queue {
-    pub(crate) fn new(requires_padding: bool, block_size: u32, window_size: Option<u32>) -> Self {
+    pub(crate) fn new(
+        requires_padding: bool,
+        max_input_length: u32,
+        block_size: u32,
+        window_size: Option<u32>
+    ) -> Self {
         // Create channel
         let (queue_sender, queue_receiver) = mpsc::unbounded_channel();

         // Launch background queue task
         tokio::spawn(queue_task(
             requires_padding,
+            max_input_length,
             block_size,
             window_size,
             queue_receiver,
@@ -89,11 +95,12 @@ impl Queue {
 // Background task responsible of the queue state
 async fn queue_task(
     requires_padding: bool,
+    max_input_length: u32,
     block_size: u32,
     window_size: Option<u32>,
     mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
 ) {
-    let mut state = State::new(requires_padding, block_size, window_size);
+    let mut state = State::new(requires_padding, max_input_length, block_size, window_size);

     while let Some(cmd) = receiver.recv().await {
         match cmd {
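The three hunks above thread the configured max_input_length from Infer::new through Queue::new and into the spawned queue_task, which owns the queue state. Below is a rough, dependency-free sketch of that pattern; the real router uses tokio::spawn and an unbounded tokio channel, and the QueueCommand payload, the values in main, and the printed output are illustrative assumptions, not the router's API.

    use std::sync::mpsc;
    use std::thread;

    // Stand-in for the router's QueueCommand; the real enum carries request entries.
    enum QueueCommand {
        Append(u32), // assumed payload: the request's input_length
        Stop,
    }

    // Mirrors the shape of queue_task after this commit: max_input_length is moved
    // into the background task together with the other batching settings.
    fn queue_task(
        requires_padding: bool,
        max_input_length: u32,
        block_size: u32,
        receiver: mpsc::Receiver<QueueCommand>,
    ) {
        while let Ok(cmd) = receiver.recv() {
            match cmd {
                QueueCommand::Append(input_length) => {
                    // With padding, every queued request is budgeted at max_input_length
                    // tokens; without it, the cost is rounded up to the block size.
                    let budgeted = if requires_padding {
                        max_input_length
                    } else {
                        ((input_length + block_size - 1) / block_size) * block_size
                    };
                    println!("append: input_length={input_length}, budgeted={budgeted}");
                }
                QueueCommand::Stop => break,
            }
        }
    }

    fn main() {
        let (tx, rx) = mpsc::channel();
        // Assumed values; in the router they come from Infer::new and the shard info.
        let task = thread::spawn(move || queue_task(true, 1024, 16, rx));
        tx.send(QueueCommand::Append(37)).unwrap();
        tx.send(QueueCommand::Stop).unwrap();
        task.join().unwrap();
    }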
@@ -131,6 +138,9 @@ struct State {
     /// Whether the model is using padding
     requires_padding: bool,

+    /// Maximum input length, required for padding scenario
+    max_input_length: u32,
+
     /// Paged Attention block size
     block_size: u32,

@@ -139,12 +149,18 @@ struct State {
 }

 impl State {
-    fn new(requires_padding: bool, block_size: u32, window_size: Option<u32>) -> Self {
+    fn new(
+        requires_padding: bool,
+        max_input_length: u32,
+        block_size: u32,
+        window_size: Option<u32>
+    ) -> Self {
         Self {
             entries: VecDeque::with_capacity(128),
             next_id: 0,
             next_batch_id: 0,
             requires_padding,
+            max_input_length,
             block_size,
             window_size,
         }
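Taken together, the two State hunks add the limit as a long-lived field next to the existing padding and block-size settings. The following compact, compilable stand-in shows where the new field sits; the entry type and the values in main are placeholders, not the router's real types.

    use std::collections::VecDeque;

    // Trimmed stand-in for the queue State; only the fields this commit touches
    // and the ones set in the constructor are shown.
    struct State {
        entries: VecDeque<u32>, // placeholder entry type
        next_id: u64,
        next_batch_id: u64,
        /// Whether the model is using padding
        requires_padding: bool,
        /// Maximum input length, required for padding scenario
        max_input_length: u32,
        /// Paged Attention block size
        block_size: u32,
        /// Sliding window size, if any
        window_size: Option<u32>,
    }

    impl State {
        fn new(
            requires_padding: bool,
            max_input_length: u32,
            block_size: u32,
            window_size: Option<u32>,
        ) -> Self {
            Self {
                entries: VecDeque::with_capacity(128),
                next_id: 0,
                next_batch_id: 0,
                requires_padding,
                max_input_length,
                block_size,
                window_size,
            }
        }
    }

    fn main() {
        // Assumed values for illustration.
        let state = State::new(true, 1024, 16, None);
        println!(
            "padding={} max_input_length={} block_size={} window={:?} queued={}",
            state.requires_padding, state.max_input_length, state.block_size,
            state.window_size, state.entries.len()
        );
    }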
@@ -187,7 +203,6 @@ impl State {
         let mut batch_entries =
             IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());

-        let mut max_input_length = 0;
         let mut prefill_tokens: u32 = 0;
         let mut decode_tokens: u32 = 0;

@@ -203,8 +218,7 @@ impl State {
             if self.requires_padding {
                 // We pad to max input length in the Python shards
                 // We need to take these padding tokens into the equation
-                max_input_length = max_input_length.max(entry.request.input_length);
-                prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length
+                prefill_tokens = (batch_requests.len() + 1) as u32 * self.max_input_length
             } else {
                 // pad to block size
                 prefill_tokens += ((entry.request.input_length + self.block_size - 1)
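To make the budgeting change concrete, here is a self-contained sketch contrasting the two prefill budgets the hunks above switch between: the old running maximum over the requests already in the batch versus the new fixed padding to the configured max_input_length, plus the block-size rounding used when the backend does not pad. The helper names and the numbers in main are illustrative only, not the router's API.

    /// Old behaviour: the budget grows with the longest request seen so far,
    /// i.e. every request is assumed to be padded to the current batch maximum.
    fn prefill_budget_batch_max(input_lengths: &[u32]) -> u32 {
        let mut max_input_length = 0;
        let mut prefill_tokens = 0;
        for (i, &len) in input_lengths.iter().enumerate() {
            max_input_length = max_input_length.max(len);
            prefill_tokens = (i as u32 + 1) * max_input_length;
        }
        prefill_tokens
    }

    /// New behaviour: every request is budgeted as if padded to the configured
    /// max_input_length, matching what the Python shards actually allocate.
    fn prefill_budget_max_input_length(n_requests: u32, max_input_length: u32) -> u32 {
        n_requests * max_input_length
    }

    /// Non-padding backends round each request up to the Paged Attention block size.
    fn prefill_budget_block_padded(input_lengths: &[u32], block_size: u32) -> u32 {
        input_lengths
            .iter()
            .map(|&len| ((len + block_size - 1) / block_size) * block_size)
            .sum()
    }

    fn main() {
        let input_lengths = [37, 120, 512]; // illustrative request lengths
        let max_input_length = 1024;        // illustrative configured limit
        // Old budget: 3 * 512 = 1536 tokens
        println!("old (batch max): {}", prefill_budget_batch_max(&input_lengths));
        // New budget: 3 * 1024 = 3072 tokens
        println!("new (padded to limit): {}", prefill_budget_max_input_length(3, max_input_length));
        // Block-padded budget: 48 + 128 + 512 = 688 tokens
        println!("block padded (16): {}", prefill_budget_block_padded(&input_lengths, 16));
    }

The new budget is deliberately more conservative: it assumes every padded request costs the full configured limit, rather than the longest input that happens to be in the batch.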
@@ -595,6 +595,7 @@ pub async fn run(
         max_waiting_tokens,
         max_concurrent_requests,
         shard_info.requires_padding,
+        max_input_length as u32,
         shard_info.window_size,
         generation_health,
     );
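Finally, at the run() call site the server-side max_input_length is cast with "as u32" before it reaches Infer::new. A trivial sketch of that conversion; the usize type and the value are assumptions inferred from the cast shown above.

    fn main() {
        // Assumed: the server-side setting is a wider integer (e.g. usize), hence the cast.
        let max_input_length: usize = 1024;
        let for_queue = max_input_length as u32; // silently truncates if the value ever exceeded u32::MAX
        println!("budgeting padded prefills with max_input_length = {for_queue}");
    }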