fix queue

OlivierDehaene 2023-04-19 16:24:18 +02:00
parent d9578153cb
commit 118f33d9dc
8 changed files with 280 additions and 183 deletions

View File

@@ -3,6 +3,7 @@ use crate::infer::InferStreamResponse;
 use crate::validation::ValidGenerateRequest;
 use nohash_hasher::{BuildNoHashHasher, IntMap};
 use std::cmp::min;
+use std::collections::VecDeque;
 use text_generation_client::{Batch, Request};
 use tokio::sync::oneshot;
 use tokio::time::Instant;
@@ -102,7 +103,7 @@ async fn queue_task(receiver: flume::Receiver<QueueCommand>) {
 #[derive(Debug)]
 struct State {
     /// Queue entries organized in a Vec
-    entries: Vec<(u64, Entry)>,
+    entries: VecDeque<(u64, Entry)>,

     /// Id of the next entry
     next_id: u64,
@@ -114,7 +115,7 @@ struct State {
 impl State {
     fn new() -> Self {
         Self {
-            entries: Vec::with_capacity(128),
+            entries: VecDeque::with_capacity(128),
             next_id: 0,
             next_batch_id: 0,
         }
@@ -127,7 +128,7 @@ impl State {
         entry.temp_span = Some(queue_span);

         // Push entry in the queue
-        self.entries.push((self.next_id, entry));
+        self.entries.push_back((self.next_id, entry));
         self.next_id += 1;
         metrics::increment_gauge!("tgi_queue_size", 1.0);
     }
@@ -155,8 +156,8 @@ impl State {
         let mut batch_entries =
             IntMap::with_capacity_and_hasher(max_batch_size, BuildNoHashHasher::default());

-        // Drain next_batch_size entries
-        for (id, mut entry) in self.entries.drain(..max_batch_size) {
+        // Iterate on buffer
+        while let Some((id, mut entry)) = self.entries.pop_front() {
             // Filter entries where the response receiver was dropped (== entries where the request
             // was dropped by the client)
             if entry.response_tx.is_disconnected() {
@@ -182,6 +183,16 @@ impl State {
             entry.batch_time = Some(Instant::now());
             // Insert in batch_entries IntMap
             batch_entries.insert(id, entry);
+
+            if batch_requests.len() == max_batch_size {
+                // We have enough requests in the batch
+                break;
+            }
+        }
+
+        // Maybe all entries were dropped because their channel were closed
+        if batch_requests.is_empty() {
+            return None;
         }

         // Final batch size once we dropped entries
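Note: the new `next_batch` loop above pops entries until the batch is full instead of draining a fixed-size slice, silently discarding entries whose client has already disconnected and returning `None` when nothing usable is left. A minimal Python sketch of that behavior, with an illustrative entry shape (the router itself is Rust):

from collections import deque

def next_batch(entries: deque, max_batch_size: int):
    """Illustration only: pop entries until the batch is full, skipping
    entries whose response receiver was already dropped by the client."""
    batch = []
    while entries:
        entry = entries.popleft()
        if entry["disconnected"]:         # client went away: discard the entry
            continue
        batch.append(entry)
        if len(batch) == max_batch_size:  # we have enough requests in the batch
            break
    return batch or None                  # None if every entry was dropped

queue = deque([{"id": 0, "disconnected": True}, {"id": 1, "disconnected": False}])
assert next_batch(queue, max_batch_size=1) == [{"id": 1, "disconnected": False}]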
@@ -218,15 +229,16 @@ enum QueueCommand {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use std::sync::Arc;
     use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
-    use tokio::sync::Semaphore;
     use tracing::info_span;

-    fn default_entry() -> Entry {
-        let (response_tx, _) = flume::unbounded();
+    fn default_entry() -> (
+        Entry,
+        flume::Receiver<Result<InferStreamResponse, InferError>>,
+    ) {
+        let (response_tx, receiver_tx) = flume::unbounded();

-        Entry {
+        let entry = Entry {
             request: ValidGenerateRequest {
                 inputs: "".to_string(),
                 truncate: 0,
@@ -251,13 +263,14 @@ mod tests {
             temp_span: None,
             queue_time: Instant::now(),
             batch_time: None,
-        }
+        };
+        (entry, receiver_tx)
     }

     #[test]
     fn test_append() {
         let mut state = State::new();
-        let entry = default_entry();
+        let (entry, _guard) = default_entry();

         assert_eq!(state.next_id, 0);
         assert_eq!(state.entries.len(), 0);
@@ -266,7 +279,7 @@ mod tests {
         assert_eq!(state.next_id, 1);
         assert_eq!(state.entries.len(), 1);

-        let (id, _) = state.entries.remove(0);
+        let (id, _) = state.entries.remove(0).unwrap();
         assert_eq!(id, 0);
     }
@@ -281,8 +294,10 @@ mod tests {
     #[test]
     fn test_next_batch_min_size() {
         let mut state = State::new();
-        state.append(default_entry());
-        state.append(default_entry());
+        let (entry1, _guard1) = default_entry();
+        let (entry2, _guard2) = default_entry();
+        state.append(entry1);
+        state.append(entry2);

         let (entries, batch, _) = state.next_batch(None, 2).unwrap();
         assert_eq!(entries.len(), 2);
@@ -297,21 +312,24 @@ mod tests {
         assert_eq!(state.entries.len(), 0);
         assert_eq!(state.next_batch_id, 1);

-        state.append(default_entry());
+        let (entry3, _guard3) = default_entry();
+        state.append(entry3);

         assert!(state.next_batch(Some(2), 2).is_none());

         assert_eq!(state.next_id, 3);
         assert_eq!(state.entries.len(), 1);
-        let (id, _) = state.entries.remove(0);
+        let (id, _) = state.entries.remove(0).unwrap();
         assert_eq!(id, 2);
     }

     #[test]
     fn test_next_batch_max_size() {
         let mut state = State::new();
-        state.append(default_entry());
-        state.append(default_entry());
+        let (entry1, _guard1) = default_entry();
+        let (entry2, _guard2) = default_entry();
+        state.append(entry1);
+        state.append(entry2);

         let (entries, batch, _) = state.next_batch(None, 1).unwrap();
         assert_eq!(entries.len(), 1);
@@ -323,7 +341,8 @@ mod tests {
         assert_eq!(state.entries.len(), 1);
         assert_eq!(state.next_batch_id, 1);

-        state.append(default_entry());
+        let (entry3, _guard3) = default_entry();
+        state.append(entry3);

         let (entries, batch, _) = state.next_batch(None, 3).unwrap();
         assert_eq!(entries.len(), 2);
@@ -340,7 +359,8 @@ mod tests {
     #[tokio::test]
     async fn test_queue_append() {
         let queue = Queue::new();
-        queue.append(default_entry());
+        let (entry, _guard) = default_entry();
+        queue.append(entry);
     }

     #[tokio::test]
@@ -354,8 +374,10 @@ mod tests {
     #[tokio::test]
     async fn test_queue_next_batch_min_size() {
         let queue = Queue::new();
-        queue.append(default_entry());
-        queue.append(default_entry());
+        let (entry1, _guard1) = default_entry();
+        let (entry2, _guard2) = default_entry();
+        queue.append(entry1);
+        queue.append(entry2);

         let (entries, batch, _) = queue.next_batch(None, 2).await.unwrap();
         assert_eq!(entries.len(), 2);
@@ -366,7 +388,8 @@ mod tests {
         assert_eq!(batch.id, 0);
         assert_eq!(batch.size, 2);

-        queue.append(default_entry());
+        let (entry3, _guard3) = default_entry();
+        queue.append(entry3);

         assert!(queue.next_batch(Some(2), 2).await.is_none());
     }
@@ -374,8 +397,10 @@ mod tests {
     #[tokio::test]
     async fn test_queue_next_batch_max_size() {
         let queue = Queue::new();
-        queue.append(default_entry());
-        queue.append(default_entry());
+        let (entry1, _guard1) = default_entry();
+        let (entry2, _guard2) = default_entry();
+        queue.append(entry1);
+        queue.append(entry2);

         let (entries, batch, _) = queue.next_batch(None, 1).await.unwrap();
         assert_eq!(entries.len(), 1);
@@ -383,7 +408,8 @@ mod tests {
         assert_eq!(batch.id, 0);
         assert_eq!(batch.size, 1);

-        queue.append(default_entry());
+        let (entry3, _guard3) = default_entry();
+        queue.append(entry3);

         let (entries, batch, _) = queue.next_batch(None, 3).await.unwrap();
         assert_eq!(entries.len(), 2);
@@ -392,4 +418,13 @@ mod tests {
         assert_eq!(batch.id, 1);
         assert_eq!(batch.size, 2);
     }
+
+    #[tokio::test]
+    async fn test_queue_next_batch_dropped_receiver() {
+        let queue = Queue::new();
+        let (entry, _) = default_entry();
+        queue.append(entry);
+
+        assert!(queue.next_batch(None, 1).await.is_none());
+    }
 }

View File

@@ -1,6 +1,9 @@
 include Makefile-transformers
 include Makefile-flash-att

+unit-tests:
+	python -m pytest tests
+
 gen-server:
 	# Compile protos
 	pip install grpcio-tools==1.51.1 mypy-protobuf==3.4.0 'types-protobuf>=3.20.4' --no-cache-dir

View File

@@ -45,8 +45,9 @@ def default_bloom_batch(default_pb_batch, bloom_560m_tokenizer):
 @pytest.fixture
 def default_multi_requests_bloom_batch(default_pb_request, bloom_560m_tokenizer):
     req_0 = copy(default_pb_request)
+    req_0.id = 1
     req_1 = default_pb_request
-    req_1.id = 1
+    req_1.id = 2
     req_1.stopping_parameters.max_new_tokens = 5

     batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
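The fixture above now assigns ids 1 and 2 to the two requests, presumably so request ids stay unique across the fixtures that later get concatenated and filtered: the new `requests_idx_mapping` is keyed by request id, and a duplicate id would overwrite an earlier entry. A small illustrative sketch (not the server's code) of the collision:

def build_mapping(batches):
    """Mimics how concatenate builds requests_idx_mapping, offsetting each
    batch's positions by the cumulative batch size."""
    mapping, start_index = {}, 0
    for request_ids in batches:
        for i, request_id in enumerate(request_ids):
            mapping[request_id] = start_index + i
        start_index += len(request_ids)
    return mapping

# Duplicate ids across batches collapse onto one entry...
assert build_mapping([[0], [0, 1]]) == {0: 1, 1: 2}
# ...while unique ids keep every request addressable when filtering.
assert build_mapping([[0], [1, 2]]) == {0: 0, 1: 1, 2: 2}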
@@ -70,12 +71,17 @@ def test_batch_from_pb(default_pb_batch, default_bloom_batch):
     assert batch.past_key_values is None

-    assert torch.equal(batch.input_ids, batch.all_input_ids[:, :, 0])
+    assert all(
+        [
+            torch.equal(input_ids, all_input_ids[:, 0])
+            for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids)
+        ]
+    )

     assert batch.input_lengths == [1]

-    assert batch.size == default_pb_batch.size
-    assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
+    assert len(batch) == default_pb_batch.size
+    assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch)

     assert batch.max_input_length == batch.input_lengths[0]
@@ -97,7 +103,7 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
     assert isinstance(next_batch, CausalLMBatch)
     assert not next_batch.keys_head_dim_last

-    assert len(next_batch.all_input_ids) == next_batch.size
+    assert len(next_batch.all_input_ids) == len(next_batch)
     assert len(next_batch.all_input_ids[0]) == sequence_length + 1
     assert len(next_batch.attention_mask[0]) == 11
     assert torch.all(next_batch.all_input_ids[0][-2:] == 10264)
@@ -106,7 +112,7 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
     assert torch.all(next_batch.attention_mask[0][:2] == 1)
     assert torch.all(next_batch.attention_mask[0][2:] == 0)

-    assert next_batch.input_ids.shape == (next_batch.size, 1)
+    assert next_batch.input_ids.shape == (len(next_batch), 1)
     assert next_batch.input_ids[0, 0] == 10264

     assert next_batch.input_lengths == [2]
@@ -170,6 +176,8 @@ def test_causal_lm_generate_token_completion_multi(
         == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
     )

+    next_batch = next_batch.filter([next_batch.requests[0]])
+
     for _ in range(
         default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
         - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
@@ -269,6 +277,8 @@ def test_batch_concatenate(
         == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
     )

+    next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]])
+
     for _ in range(
         default_bloom_batch.stopping_criterias[0].max_new_tokens
         - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
@@ -290,6 +300,8 @@ def test_batch_concatenate(
         == default_bloom_batch.stopping_criterias[0].max_new_tokens
     )

+    next_batch = next_batch.filter([next_batch.requests[1]])
+
     for _ in range(
         default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
         - default_bloom_batch.stopping_criterias[0].max_new_tokens

View File

@@ -68,7 +68,12 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
     assert batch.past_key_values is None

-    assert all([torch.equal(input_ids, all_input_ids[:, 0]) for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids)])
+    assert all(
+        [
+            torch.equal(input_ids, all_input_ids[:, 0])
+            for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids)
+        ]
+    )

     assert batch.input_lengths == [1]

View File

@@ -49,8 +49,9 @@ def default_seq2seq_lm_batch(default_pb_batch, mt0_small_tokenizer):
 @pytest.fixture
 def default_multi_requests_seq2seq_lm_batch(default_pb_request, mt0_small_tokenizer):
     req_0 = copy(default_pb_request)
+    req_0.id = 1
     req_1 = default_pb_request
-    req_1.id = 1
+    req_1.id = 2
     req_1.stopping_parameters.max_new_tokens = 5

     batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
@@ -72,7 +73,7 @@ def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch):
     assert torch.all(batch.attention_mask[0][-2:] == 1)
     assert torch.all(batch.attention_mask[0][:-2] == 0)

-    assert batch.decoder_input_ids.shape == (default_pb_batch.size, 1)
+    assert len(batch.decoder_input_ids) == default_pb_batch.size
     assert batch.decoder_attention_mask is None
     assert batch.encoder_last_hidden_state is None
@@ -81,8 +82,8 @@ def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch):
     assert batch.input_lengths == [2]
     assert batch.decoder_input_lengths == [1]

-    assert batch.size == default_pb_batch.size
-    assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
+    assert len(batch) == default_pb_batch.size
+    assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch)

     assert batch.max_input_length == batch.input_lengths[0]
     assert batch.max_decoder_input_length == batch.decoder_input_lengths[0]
@@ -117,9 +118,9 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
     )
     assert next_batch.stopping_criterias == default_seq2seq_lm_batch.stopping_criterias

-    assert next_batch.decoder_input_ids.shape == (next_batch.size, 2)
-    assert next_batch.decoder_input_ids[0, 0] == 0
-    assert next_batch.decoder_input_ids[0, 1] == 259
+    assert len(next_batch.decoder_input_ids) == len(next_batch)
+    assert next_batch.all_decoder_input_ids[0][0] == 0
+    assert next_batch.all_decoder_input_ids[0][1] == 259
     assert next_batch.decoder_attention_mask is None
     assert next_batch.encoder_last_hidden_state.shape == (1, sequence_length, 512)
@@ -128,20 +129,20 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
     assert next_batch.past_key_values is not None
     assert all(
-        [p[0].shape == (next_batch.size, 6, 1, 64) for p in next_batch.past_key_values]
+        [p[0].shape == (len(next_batch), 6, 1, 64) for p in next_batch.past_key_values]
     )
     assert all(
-        [p[1].shape == (next_batch.size, 6, 1, 64) for p in next_batch.past_key_values]
+        [p[1].shape == (len(next_batch), 6, 1, 64) for p in next_batch.past_key_values]
     )
     assert all(
         [
-            p[2].shape == (next_batch.size, 6, sequence_length, 64)
+            p[2].shape == (len(next_batch), 6, sequence_length, 64)
             for p in next_batch.past_key_values
         ]
     )
     assert all(
         [
-            p[3].shape == (next_batch.size, 6, sequence_length, 64)
+            p[3].shape == (len(next_batch), 6, sequence_length, 64)
             for p in next_batch.past_key_values
         ]
     )
@@ -189,6 +190,8 @@ def test_seq2seq_lm_generate_token_completion_multi(
     )
     assert generations[1].generated_text.generated_tokens == 5

+    next_batch = next_batch.filter([next_batch.requests[0]])
+
     generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert len(generations) == len(next_batch)
@@ -223,7 +226,8 @@ def test_batch_concatenate(
     assert torch.equal(
         next_batch.decoder_input_ids[0], next_batch_0.decoder_input_ids[0]
     )
-    assert torch.all(next_batch.decoder_input_ids[1:, 0] == 0)
+    assert next_batch.all_decoder_input_ids[1][0] == 0
+    assert next_batch.all_decoder_input_ids[2][0] == 0
     assert torch.equal(
         next_batch.decoder_input_ids[1:, -2:], next_batch_1.decoder_input_ids
     )
@@ -258,16 +262,16 @@ def test_batch_concatenate(
     assert next_batch.past_key_values is not None
     assert all(
-        [p[0].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
+        [p[0].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values]
     )
     assert all(
-        [p[1].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
+        [p[1].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values]
     )
     assert all(
-        [p[2].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
+        [p[2].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values]
     )
     assert all(
-        [p[3].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
+        [p[3].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values]
     )

     for i, past in enumerate(next_batch.past_key_values):
@@ -306,6 +310,8 @@ def test_batch_concatenate(
     )
     assert generations[2].generated_text.generated_tokens == 5

+    next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]])
+
     generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None
@@ -314,6 +320,8 @@ def test_batch_concatenate(
     assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
     assert generations[0].generated_text.generated_tokens == 7

+    next_batch = next_batch.filter([next_batch.requests[1]])
+
     generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is None

View File

@@ -428,7 +428,7 @@ class CausalLM(Model):
     @tracer.start_as_current_span("generate_token")
     def generate_token(
         self, batch: CausalLMBatch
-    ) -> Tuple[List[Generation], CausalLMBatch]:
+    ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
         # slice the attention mask to the correct shape
         attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
@@ -545,6 +545,10 @@ class CausalLM(Model):
             batch.token_offsets[i] = token_offset
             batch.max_input_length = max(batch.max_input_length, new_input_length)

+        # We finished all generations in the batch; there is no next batch
+        if stopped:
+            return generations, None
+
         # Slice unused values from prefill
         batch.input_ids = batch.input_ids[:, :1]
@@ -559,4 +563,4 @@ class CausalLM(Model):
         # Update past key values
         batch.past_key_values = past

-        return generations, batch if not stopped else None
+        return generations, batch
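The causal LM change above makes the early exit explicit: the return type is now `Optional[CausalLMBatch]`, and when every request in the batch has finished the method returns before doing any prefill slicing or cache bookkeeping. A trivial sketch of the new control flow (illustrative names only, not the real method):

def generate_token_sketch(generations, batch, stopped):
    """Illustration of the early return added above."""
    if stopped:
        return generations, None   # all requests finished: no next batch
    # ... otherwise trim prefill tensors, update past_key_values ...
    return generations, batch

assert generate_token_sketch(["gen"], {"id": 0}, stopped=True) == (["gen"], None)
assert generate_token_sketch(["gen"], {"id": 0}, stopped=False) == (["gen"], {"id": 0})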

View File

@@ -96,11 +96,13 @@ class GalacticaCausalLMBatch(CausalLMBatch):
         stopping_criterias = []
         offsets = []
         token_offsets = []
+        requests_idx_mapping = {}

         # Parse batch
         max_truncation = 0
         padding_right_offset = 0
-        for r in pb.requests:
+        for i, r in enumerate(pb.requests):
+            requests_idx_mapping[r.id] = i
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
             offsets.append(None)
@@ -115,7 +117,6 @@ class GalacticaCausalLMBatch(CausalLMBatch):
                 padding_right_offset, stopping_criteria.max_new_tokens
             )

-        # Tokenize batch
         tokenized_inputs = tokenizer(
             inputs,
             return_tensors="pt",
@@ -138,23 +139,23 @@ class GalacticaCausalLMBatch(CausalLMBatch):
         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
         position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
-        all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
+        all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)

         return cls(
             batch_id=pb.id,
             requests=pb.requests,
+            requests_idx_mapping=requests_idx_mapping,
             input_ids=input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=None,
-            all_input_ids=all_input_ids,
-            input_lengths=input_lengths,
+            all_input_ids=list(all_input_ids),
+            input_lengths=input_lengths.tolist(),
             offsets=offsets,
             token_offsets=token_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
-            size=pb.size,
-            max_input_length=max_input_length,
+            max_input_length=max_input_length.item(),
             padding_right_offset=padding_right_offset,
         )
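The Galactica batch now stores `all_input_ids` as one tensor per request instead of a single stacked tensor, so individual requests can grow and be filtered independently. A short sketch of what the `.T.split(1, dim=1)` construction produces (token values are made up):

import torch

input_ids = torch.tensor([[101, 7592, 102],    # request 0
                          [101, 2088, 102]])   # request 1 -> shape [batch=2, seq=3]
all_input_ids = input_ids.T.split(1, dim=1)    # tuple of per-request [seq, 1] tensors

assert len(all_input_ids) == 2
assert all_input_ids[0].shape == (3, 1)
assert all_input_ids[0][:, 0].tolist() == [101, 7592, 102]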

View File

@@ -3,7 +3,7 @@ import torch
 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
-from typing import Optional, Tuple, List, Type
+from typing import Optional, Tuple, List, Type, Dict

 from text_generation_server.models import Model
 from text_generation_server.models.types import (
@@ -22,6 +22,7 @@ tracer = trace.get_tracer(__name__)
 class Seq2SeqLMBatch(Batch):
     batch_id: int
     requests: List[generate_pb2.Request]
+    requests_idx_mapping: Dict[int, int]

     # Encoder values
     input_ids: torch.Tensor
@@ -32,6 +33,9 @@ class Seq2SeqLMBatch(Batch):
     decoder_attention_mask: Optional[torch.Tensor]
     encoder_last_hidden_state: Optional[torch.Tensor]

+    # All tokens
+    all_decoder_input_ids: List[torch.Tensor]
+
     # Seq2SeqLM keeps track of both encoder and decoder attention keys and values
     past_key_values: Optional[List[Tuple]]
@@ -46,7 +50,6 @@ class Seq2SeqLMBatch(Batch):
     stopping_criterias: List[StoppingCriteria]

     # Metadata used for padding
-    size: int
     max_input_length: int
     max_decoder_input_length: int
     padding_right_offset: int
@@ -54,9 +57,7 @@ class Seq2SeqLMBatch(Batch):
     def to_pb(self) -> generate_pb2.Batch:
         """Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf"""
         return generate_pb2.Batch(
-            id=self.batch_id,
-            requests=self.requests,
-            size=self.size,
+            id=self.batch_id, requests=self.requests, size=len(self)
         )

     @classmethod
@@ -71,18 +72,17 @@ class Seq2SeqLMBatch(Batch):
         next_token_choosers = []
         stopping_criterias = []

-        decoder_input_ids = []
         decoder_input_lengths = []
         offsets = []
         token_offsets = []
+        requests_idx_mapping = {}

         # Parse batch
         max_truncation = 0
         padding_right_offset = 0
-        for r in pb.requests:
+        for i, r in enumerate(pb.requests):
             inputs.append(r.inputs)
-            # Decoder sequence only contains the bos_token
-            decoder_input_ids.append(tokenizer.bos_token_id)
+            requests_idx_mapping[r.id] = i
             decoder_input_lengths.append(1)
             offsets.append(None)
             token_offsets.append(None)
@@ -109,15 +109,22 @@ class Seq2SeqLMBatch(Batch):
         input_lengths = tokenized_inputs["attention_mask"].sum(1)
         max_input_length = input_lengths.max()

-        # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
-        decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)
+        # Decoder sequence only contains the bos_token
+        decoder_input_ids = (
+            torch.tensor(tokenizer.bos_token_id, device=device)
+            .repeat(len(pb.requests))
+            .view(-1, 1)
+        )
+        all_decoder_input_ids = decoder_input_ids.view(-1).split(1)

         return cls(
             batch_id=pb.id,
             requests=pb.requests,
+            requests_idx_mapping=requests_idx_mapping,
             input_ids=tokenized_inputs["input_ids"],
             attention_mask=tokenized_inputs["attention_mask"],
             decoder_input_ids=decoder_input_ids,
+            all_decoder_input_ids=list(all_decoder_input_ids),
             decoder_attention_mask=None,
             encoder_last_hidden_state=None,
             past_key_values=None,
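In the new `from_pb`, `decoder_input_ids` starts as a `[batch_size, 1]` tensor holding one bos token per request, and `all_decoder_input_ids` is a per-request view of it that will accumulate generated tokens. A minimal sketch of that initialization (assuming a bos token id of 0, consistent with the test assertions above):

import torch

bos_token_id, n_requests = 0, 3
decoder_input_ids = torch.tensor(bos_token_id).repeat(n_requests).view(-1, 1)
all_decoder_input_ids = decoder_input_ids.view(-1).split(1)

assert decoder_input_ids.shape == (3, 1)
assert [t.item() for t in all_decoder_input_ids] == [0, 0, 0]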
@@ -127,12 +134,96 @@ class Seq2SeqLMBatch(Batch):
             token_offsets=token_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
-            size=len(pb.requests),
             max_input_length=max_input_length.item(),
             max_decoder_input_length=1,
             padding_right_offset=padding_right_offset,
         )

+    @tracer.start_as_current_span("filter")
+    def filter(
+        self, requests: List[generate_pb2.Request]
+    ) -> Optional["Seq2SeqLMBatch"]:
+        if len(requests) == 0:
+            raise ValueError("Batch must have at least one request")
+        if len(requests) == len(self):
+            return self
+
+        keep_indices = []
+
+        # New values after filtering
+        requests_idx_mapping = {}
+        input_lengths = []
+        decoder_input_lengths = []
+        offsets = []
+        token_offsets = []
+
+        all_decoder_input_ids = []
+
+        next_token_choosers = []
+        stopping_criterias = []
+
+        max_input_length = 0
+        max_decoder_input_length = 0
+
+        for i, r in enumerate(requests):
+            idx = self.requests_idx_mapping[r.id]
+            requests_idx_mapping[r.id] = i
+            keep_indices.append(idx)
+
+            offsets.append(self.offsets[idx])
+            token_offsets.append(self.token_offsets[idx])
+
+            all_decoder_input_ids.append(self.all_decoder_input_ids[idx])
+
+            request_input_length = self.input_lengths[idx]
+            input_lengths.append(request_input_length)
+            max_input_length = max(max_input_length, request_input_length)
+
+            request_decoder_input_length = self.decoder_input_lengths[idx]
+            decoder_input_lengths.append(request_decoder_input_length)
+            max_decoder_input_length = max(
+                max_decoder_input_length, request_decoder_input_length
+            )
+
+            next_token_choosers.append(self.next_token_choosers[idx])
+            stopping_criterias.append(self.stopping_criterias[idx])
+
+        # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
+        decoder_input_ids = self.decoder_input_ids[keep_indices]
+        attention_mask = self.attention_mask[keep_indices]
+        if self.decoder_attention_mask is not None:
+            decoder_attention_mask = self.decoder_attention_mask[keep_indices]
+        else:
+            decoder_attention_mask = None
+
+        encoder_last_hidden_state = self.encoder_last_hidden_state[keep_indices]
+
+        past_key_values = [
+            [t[keep_indices] for t in layer] for layer in self.past_key_values
+        ]
+
+        return Seq2SeqLMBatch(
+            batch_id=self.batch_id,
+            requests=requests,
+            requests_idx_mapping=requests_idx_mapping,
+            input_ids=None,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            all_decoder_input_ids=all_decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            encoder_last_hidden_state=encoder_last_hidden_state,
+            past_key_values=past_key_values,
+            input_lengths=input_lengths,
+            decoder_input_lengths=decoder_input_lengths,
+            offsets=offsets,
+            token_offsets=token_offsets,
+            next_token_choosers=next_token_choosers,
+            stopping_criterias=stopping_criterias,
+            max_input_length=max_input_length,
+            max_decoder_input_length=max_decoder_input_length,
+            padding_right_offset=self.padding_right_offset,
+        )
+
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
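The new `filter` method above keeps a batch alive after some of its requests finish: it looks up each surviving request's position through `requests_idx_mapping`, rebuilds the per-request lists, and indexes the cached tensors with `keep_indices`. A self-contained sketch of that pattern on toy state (not the server's class):

import torch

def filter_state(requests_idx_mapping, keep_ids, list_state, tensor_state):
    keep_indices = [requests_idx_mapping[r] for r in keep_ids]
    new_mapping = {r: i for i, r in enumerate(keep_ids)}
    new_lists = {k: [v[i] for i in keep_indices] for k, v in list_state.items()}
    new_tensors = {k: t[keep_indices] for k, t in tensor_state.items()}
    return new_mapping, new_lists, new_tensors

mapping = {10: 0, 11: 1, 12: 2}                        # request id -> row index
lists = {"input_lengths": [5, 7, 3]}
tensors = {"attention_mask": torch.arange(6).reshape(3, 2)}

new_mapping, new_lists, new_tensors = filter_state(mapping, [12, 10], lists, tensors)
assert new_mapping == {12: 0, 10: 1}
assert new_lists["input_lengths"] == [3, 5]
assert new_tensors["attention_mask"].tolist() == [[4, 5], [0, 1]]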
@@ -144,7 +235,7 @@ class Seq2SeqLMBatch(Batch):
         max_decoder_input_length = 0
         padding_right_offset = 0
         for batch in batches:
-            total_batch_size += batch.size
+            total_batch_size += len(batch)
             max_input_length = max(max_input_length, batch.max_input_length)
             max_decoder_input_length = max(
                 max_decoder_input_length, batch.max_decoder_input_length
@@ -153,6 +244,8 @@ class Seq2SeqLMBatch(Batch):
         # Batch attributes
         requests = []
+        requests_idx_mapping = {}
+        all_decoder_input_ids = []
         input_lengths = []
         decoder_input_lengths = []
         offsets = []
@@ -174,6 +267,7 @@ class Seq2SeqLMBatch(Batch):
         for i, batch in enumerate(batches):
             # Extend all list attributes
             requests.extend(batch.requests)
+            all_decoder_input_ids.extend(batch.all_decoder_input_ids)
             input_lengths.extend(batch.input_lengths)
             decoder_input_lengths.extend(batch.decoder_input_lengths)
             offsets.extend(batch.offsets)
@@ -181,8 +275,15 @@ class Seq2SeqLMBatch(Batch):
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)

+            if i == 0:
+                requests_idx_mapping = batch.requests_idx_mapping
+            else:
+                # We need to offset the mapping for each batch by the cumulative batch size
+                for k, v in batch.requests_idx_mapping.items():
+                    requests_idx_mapping[k] = v + start_index
+
             # Slicing end index for this batch
-            end_index = start_index + batch.size
+            end_index = start_index + len(batch)

             # We only concatenate batches that did at least one step
             if batch.encoder_last_hidden_state is None:
@@ -201,12 +302,10 @@ class Seq2SeqLMBatch(Batch):
             # Create padded tensor
             if decoder_input_ids is None:
                 decoder_input_ids = batch.decoder_input_ids.new_zeros(
-                    (total_batch_size, max_decoder_input_length),
+                    (total_batch_size, 1),
                 )
             # Copy to correct indices
-            decoder_input_ids[
-                start_index:end_index, -batch.max_decoder_input_length :
-            ] = batch.decoder_input_ids[:, -batch.max_decoder_input_length :]
+            decoder_input_ids[start_index:end_index] = batch.decoder_input_ids

             # Create padded tensor
             if decoder_attention_mask is None:
@@ -302,14 +401,16 @@ class Seq2SeqLMBatch(Batch):
                     start_index:end_index, :, -batch.max_input_length :, :
                 ] = t[:, :, -batch.max_input_length :, :]

-            start_index += batch.size
+            start_index += len(batch)

         return cls(
             batch_id=batches[0].batch_id,
             requests=requests,
+            requests_idx_mapping=requests_idx_mapping,
             input_ids=None,
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
+            all_decoder_input_ids=all_decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
             encoder_last_hidden_state=encoder_last_hidden_state,
             past_key_values=past_key_values,
@@ -319,7 +420,6 @@ class Seq2SeqLMBatch(Batch):
             token_offsets=token_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
-            size=total_batch_size,
             max_input_length=max_input_length,
             max_decoder_input_length=max_decoder_input_length,
             padding_right_offset=padding_right_offset,
@@ -413,46 +513,25 @@ class Seq2SeqLM(Model):
         else:
             decoder_attention_mask = None

-        # check if first forward or not
-        if batch.past_key_values is not None:
-            # Only take the last token
-            decoder_input_ids = batch.decoder_input_ids[:, -1].unsqueeze(-1)
-        else:
-            decoder_input_ids = batch.decoder_input_ids
-
         # Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
         # internally...
         if batch.encoder_last_hidden_state is not None:
             encoder_last_hidden_state = [batch.encoder_last_hidden_state]
         else:
-            encoder_last_hidden_state = batch.encoder_last_hidden_state
+            encoder_last_hidden_state = None

         logits, encoder_last_hidden_state, past = self.forward(
             batch.input_ids,
             batch.attention_mask,
-            decoder_input_ids,
+            batch.decoder_input_ids,
             decoder_attention_mask,
             encoder_last_hidden_state,
             batch.past_key_values,
         )

-        # List of indices to cache
-        next_batch_keep_indices = []
-
-        # New values for next forward
-        next_batch_input_lengths = []
-        next_batch_offsets = []
-        next_batch_token_offsets = []
-        next_batch_decoder_input_ids = []
-        next_batch_decoder_input_lengths = []
-
-        # Metadata
-        next_batch_size = 0
-        next_batch_max_input_length = 0
-        next_batch_max_decoder_input_length = 0
-
         # Finished requests
         generations: List[Generation] = []
+        stopped = True

         # Zipped iterator
         iterator = zip(
@@ -464,7 +543,7 @@ class Seq2SeqLM(Model):
             logits,
             batch.next_token_choosers,
             batch.stopping_criterias,
-            batch.decoder_input_ids,
+            batch.all_decoder_input_ids,
         )

         # For each member of the batch
@@ -477,22 +556,24 @@ class Seq2SeqLM(Model):
             logits,
             next_token_chooser,
             stopping_criteria,
-            decoder_input_ids,
+            all_decoder_input_ids,
         ) in enumerate(iterator):
             # Select next token
             next_token_id, logprobs = next_token_chooser(
-                decoder_input_ids.view(1, -1), logits
+                all_decoder_input_ids.view(1, -1), logits
             )

             # Append next token to decoder tokens
-            decoder_input_ids = torch.cat([decoder_input_ids, next_token_id.squeeze(1)])
+            all_decoder_input_ids = torch.cat(
+                [all_decoder_input_ids, next_token_id.squeeze(1)]
+            )
             new_decoder_input_length = decoder_input_length + 1

             # Generated token
             next_token_logprob = logprobs[-1, next_token_id]
             next_token_id_squeezed = next_token_id.squeeze()
             next_token_text, offset, token_offset = self.decode_token(
-                decoder_input_ids, offset, token_offset
+                all_decoder_input_ids, offset, token_offset
             )

             # Evaluate stopping criteria
@@ -501,7 +582,7 @@ class Seq2SeqLM(Model):
             if stop:
                 # Slice with decoder_input_length to remove padding
                 # Decode all tokens
-                output_text = self.decode(decoder_input_ids[-decoder_input_length:])
+                output_text = self.decode(all_decoder_input_ids[-decoder_input_length:])

                 # Get seed
                 if isinstance(next_token_chooser.choice, Sampling):
@@ -515,19 +596,7 @@ class Seq2SeqLM(Model):
             else:
                 # Keep request in the batch
                 generated_text = None
-                next_batch_keep_indices.append(i)
-                next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0))
-                next_batch_size += 1
-                next_batch_input_lengths.append(input_length)
-                next_batch_decoder_input_lengths.append(new_decoder_input_length)
-                next_batch_offsets.append(offset)
-                next_batch_token_offsets.append(token_offset)
-                next_batch_max_input_length = max(
-                    next_batch_max_input_length, input_length
-                )
-                next_batch_max_decoder_input_length = max(
-                    next_batch_max_decoder_input_length, new_decoder_input_length
-                )
+                stopped = False

             # Prefill
             if stopping_criteria.current_tokens == 1:
@@ -551,69 +620,29 @@ class Seq2SeqLM(Model):

             generations.append(generation)

+            # Update values
+            batch.decoder_input_ids[i] = next_token_id
+            batch.all_decoder_input_ids[i] = all_decoder_input_ids
+            batch.input_lengths[i] = input_length
+            batch.decoder_input_lengths[i] = new_decoder_input_length
+            batch.offsets[i] = offset
+            batch.token_offsets[i] = token_offset
+            batch.max_input_length = max(batch.max_input_length, input_length)
+            batch.max_decoder_input_length = max(
+                batch.max_decoder_input_length, new_decoder_input_length
+            )
+
         # We finished all generations in the batch; there is no next batch
-        if not next_batch_keep_indices:
+        if stopped:
             return generations, None

-        next_batch_decoder_input_ids = torch.cat(next_batch_decoder_input_ids)
-        # If we finished at least one generation, we need to evict the indices of the generations that finished
-        # from the values of the next batch
-        if len(next_batch_keep_indices) != len(batch):
-            # Apply indices to decoder_attention mask, past key values and other items that need to be cached
-            next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
-            if batch.decoder_attention_mask is not None:
-                next_batch_decoder_attention_mask = batch.decoder_attention_mask[
-                    next_batch_keep_indices
-                ]
-            else:
-                next_batch_decoder_attention_mask = None
-
-            next_batch_encoder_last_hidden_state = encoder_last_hidden_state[
-                next_batch_keep_indices
-            ]
-
-            next_batch_past_key_values = [
-                [t[next_batch_keep_indices] for t in layer] for layer in past
-            ]
-            next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
-            next_batch_next_token_choosers = [
-                batch.next_token_choosers[i] for i in next_batch_keep_indices
-            ]
-            next_batch_stopping_criterias = [
-                batch.stopping_criterias[i] for i in next_batch_keep_indices
-            ]
-        else:
-            next_batch_attention_mask = batch.attention_mask
-            next_batch_decoder_attention_mask = batch.decoder_attention_mask
-            next_batch_encoder_last_hidden_state = encoder_last_hidden_state
-            next_batch_past_key_values = past
-            next_batch_requests = batch.requests
-            next_batch_next_token_choosers = batch.next_token_choosers
-            next_batch_stopping_criterias = batch.stopping_criterias
+        # We don't need input_ids after the prefill forward
+        batch.input_ids = None
+        batch.encoder_last_hidden_state = encoder_last_hidden_state
+        batch.past_key_values = past

         # Update decoder_attention_mask as we added a new token to input_ids
-        if next_batch_decoder_attention_mask is not None:
-            next_batch_decoder_attention_mask[:, -batch.padding_right_offset] = 1
+        if batch.decoder_attention_mask is not None:
+            batch.decoder_attention_mask[:, -batch.padding_right_offset] = 1
+        batch.padding_right_offset -= 1

-        next_batch = Seq2SeqLMBatch(
-            batch_id=batch.batch_id,
-            requests=next_batch_requests,
-            input_ids=None,
-            attention_mask=next_batch_attention_mask,
-            decoder_input_ids=next_batch_decoder_input_ids,
-            decoder_attention_mask=next_batch_decoder_attention_mask,
-            encoder_last_hidden_state=next_batch_encoder_last_hidden_state,
-            past_key_values=next_batch_past_key_values,
-            input_lengths=next_batch_input_lengths,
-            decoder_input_lengths=next_batch_decoder_input_lengths,
-            offsets=next_batch_offsets,
-            token_offsets=next_batch_token_offsets,
-            next_token_choosers=next_batch_next_token_choosers,
-            stopping_criterias=next_batch_stopping_criterias,
-            size=next_batch_size,
-            max_input_length=next_batch_max_input_length,
-            max_decoder_input_length=next_batch_max_decoder_input_length,
-            padding_right_offset=batch.padding_right_offset - 1,
-        )
-        return generations, next_batch
+        return generations, batch
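With these changes `generate_token` mutates the batch in place instead of rebuilding a new `Seq2SeqLMBatch` every step: `decoder_input_ids` only ever carries the last generated token per sequence (shape `[batch, 1]`, since `past_key_values` hold the rest), while `batch.all_decoder_input_ids[i]` keeps the full sequence needed for detokenization. A small sketch of that bookkeeping for one decode step (token values are made up):

import torch

decoder_input_ids = torch.tensor([[0], [0]])                # bos for two sequences
all_decoder_input_ids = [t.view(-1) for t in decoder_input_ids.split(1)]

next_token_ids = torch.tensor([[259], [417]])               # one new token per sequence
for i, next_token_id in enumerate(next_token_ids):
    all_decoder_input_ids[i] = torch.cat([all_decoder_input_ids[i], next_token_id])
    decoder_input_ids[i] = next_token_id                    # only this is fed next step

assert decoder_input_ids.shape == (2, 1)
assert all_decoder_input_ids[0].tolist() == [0, 259]
assert all_decoder_input_ids[1].tolist() == [0, 417]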