diff --git a/router/src/queue.rs b/router/src/queue.rs index 93855827..c2220cb0 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -3,6 +3,7 @@ use crate::infer::InferStreamResponse; use crate::validation::ValidGenerateRequest; use nohash_hasher::{BuildNoHashHasher, IntMap}; use std::cmp::min; +use std::collections::VecDeque; use text_generation_client::{Batch, Request}; use tokio::sync::oneshot; use tokio::time::Instant; @@ -102,7 +103,7 @@ async fn queue_task(receiver: flume::Receiver) { #[derive(Debug)] struct State { /// Queue entries organized in a Vec - entries: Vec<(u64, Entry)>, + entries: VecDeque<(u64, Entry)>, /// Id of the next entry next_id: u64, @@ -114,7 +115,7 @@ struct State { impl State { fn new() -> Self { Self { - entries: Vec::with_capacity(128), + entries: VecDeque::with_capacity(128), next_id: 0, next_batch_id: 0, } @@ -127,7 +128,7 @@ impl State { entry.temp_span = Some(queue_span); // Push entry in the queue - self.entries.push((self.next_id, entry)); + self.entries.push_back((self.next_id, entry)); self.next_id += 1; metrics::increment_gauge!("tgi_queue_size", 1.0); } @@ -155,8 +156,8 @@ impl State { let mut batch_entries = IntMap::with_capacity_and_hasher(max_batch_size, BuildNoHashHasher::default()); - // Drain next_batch_size entries - for (id, mut entry) in self.entries.drain(..max_batch_size) { + // Iterate on buffer + while let Some((id, mut entry)) = self.entries.pop_front() { // Filter entries where the response receiver was dropped (== entries where the request // was dropped by the client) if entry.response_tx.is_disconnected() { @@ -182,6 +183,16 @@ impl State { entry.batch_time = Some(Instant::now()); // Insert in batch_entries IntMap batch_entries.insert(id, entry); + + if batch_requests.len() == max_batch_size { + // We have enough requests in the batch + break; + } + } + + // Maybe all entries were dropped because their channel were closed + if batch_requests.is_empty() { + return None; } // Final batch size once we dropped entries @@ -218,15 +229,16 @@ enum QueueCommand { #[cfg(test)] mod tests { use super::*; - use std::sync::Arc; use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters}; - use tokio::sync::Semaphore; use tracing::info_span; - fn default_entry() -> Entry { - let (response_tx, _) = flume::unbounded(); + fn default_entry() -> ( + Entry, + flume::Receiver>, + ) { + let (response_tx, receiver_tx) = flume::unbounded(); - Entry { + let entry = Entry { request: ValidGenerateRequest { inputs: "".to_string(), truncate: 0, @@ -251,13 +263,14 @@ mod tests { temp_span: None, queue_time: Instant::now(), batch_time: None, - } + }; + (entry, receiver_tx) } #[test] fn test_append() { let mut state = State::new(); - let entry = default_entry(); + let (entry, _guard) = default_entry(); assert_eq!(state.next_id, 0); assert_eq!(state.entries.len(), 0); @@ -266,7 +279,7 @@ mod tests { assert_eq!(state.next_id, 1); assert_eq!(state.entries.len(), 1); - let (id, _) = state.entries.remove(0); + let (id, _) = state.entries.remove(0).unwrap(); assert_eq!(id, 0); } @@ -281,8 +294,10 @@ mod tests { #[test] fn test_next_batch_min_size() { let mut state = State::new(); - state.append(default_entry()); - state.append(default_entry()); + let (entry1, _guard1) = default_entry(); + let (entry2, _guard2) = default_entry(); + state.append(entry1); + state.append(entry2); let (entries, batch, _) = state.next_batch(None, 2).unwrap(); assert_eq!(entries.len(), 2); @@ -297,21 +312,24 @@ mod tests { assert_eq!(state.entries.len(), 
0); assert_eq!(state.next_batch_id, 1); - state.append(default_entry()); + let (entry3, _guard3) = default_entry(); + state.append(entry3); assert!(state.next_batch(Some(2), 2).is_none()); assert_eq!(state.next_id, 3); assert_eq!(state.entries.len(), 1); - let (id, _) = state.entries.remove(0); + let (id, _) = state.entries.remove(0).unwrap(); assert_eq!(id, 2); } #[test] fn test_next_batch_max_size() { let mut state = State::new(); - state.append(default_entry()); - state.append(default_entry()); + let (entry1, _guard1) = default_entry(); + let (entry2, _guard2) = default_entry(); + state.append(entry1); + state.append(entry2); let (entries, batch, _) = state.next_batch(None, 1).unwrap(); assert_eq!(entries.len(), 1); @@ -323,7 +341,8 @@ mod tests { assert_eq!(state.entries.len(), 1); assert_eq!(state.next_batch_id, 1); - state.append(default_entry()); + let (entry3, _guard3) = default_entry(); + state.append(entry3); let (entries, batch, _) = state.next_batch(None, 3).unwrap(); assert_eq!(entries.len(), 2); @@ -340,7 +359,8 @@ mod tests { #[tokio::test] async fn test_queue_append() { let queue = Queue::new(); - queue.append(default_entry()); + let (entry, _guard) = default_entry(); + queue.append(entry); } #[tokio::test] @@ -354,8 +374,10 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_min_size() { let queue = Queue::new(); - queue.append(default_entry()); - queue.append(default_entry()); + let (entry1, _guard1) = default_entry(); + let (entry2, _guard2) = default_entry(); + queue.append(entry1); + queue.append(entry2); let (entries, batch, _) = queue.next_batch(None, 2).await.unwrap(); assert_eq!(entries.len(), 2); @@ -366,7 +388,8 @@ mod tests { assert_eq!(batch.id, 0); assert_eq!(batch.size, 2); - queue.append(default_entry()); + let (entry3, _guard3) = default_entry(); + queue.append(entry3); assert!(queue.next_batch(Some(2), 2).await.is_none()); } @@ -374,8 +397,10 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_max_size() { let queue = Queue::new(); - queue.append(default_entry()); - queue.append(default_entry()); + let (entry1, _guard1) = default_entry(); + let (entry2, _guard2) = default_entry(); + queue.append(entry1); + queue.append(entry2); let (entries, batch, _) = queue.next_batch(None, 1).await.unwrap(); assert_eq!(entries.len(), 1); @@ -383,7 +408,8 @@ mod tests { assert_eq!(batch.id, 0); assert_eq!(batch.size, 1); - queue.append(default_entry()); + let (entry3, _guard3) = default_entry(); + queue.append(entry3); let (entries, batch, _) = queue.next_batch(None, 3).await.unwrap(); assert_eq!(entries.len(), 2); @@ -392,4 +418,13 @@ mod tests { assert_eq!(batch.id, 1); assert_eq!(batch.size, 2); } + + #[tokio::test] + async fn test_queue_next_batch_dropped_receiver() { + let queue = Queue::new(); + let (entry, _) = default_entry(); + queue.append(entry); + + assert!(queue.next_batch(None, 1).await.is_none()); + } } diff --git a/server/Makefile b/server/Makefile index 058d912b..150d7e4a 100644 --- a/server/Makefile +++ b/server/Makefile @@ -1,6 +1,9 @@ include Makefile-transformers include Makefile-flash-att +unit-tests: + python -m pytest tests + gen-server: # Compile protos pip install grpcio-tools==1.51.1 mypy-protobuf==3.4.0 'types-protobuf>=3.20.4' --no-cache-dir diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py index a9831cd7..de0ef57b 100644 --- a/server/tests/models/test_bloom.py +++ b/server/tests/models/test_bloom.py @@ -45,8 +45,9 @@ def default_bloom_batch(default_pb_batch, bloom_560m_tokenizer): 
@pytest.fixture def default_multi_requests_bloom_batch(default_pb_request, bloom_560m_tokenizer): req_0 = copy(default_pb_request) + req_0.id = 1 req_1 = default_pb_request - req_1.id = 1 + req_1.id = 2 req_1.stopping_parameters.max_new_tokens = 5 batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2) @@ -70,12 +71,17 @@ def test_batch_from_pb(default_pb_batch, default_bloom_batch): assert batch.past_key_values is None - assert torch.equal(batch.input_ids, batch.all_input_ids[:, :, 0]) + assert all( + [ + torch.equal(input_ids, all_input_ids[:, 0]) + for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids) + ] + ) assert batch.input_lengths == [1] - assert batch.size == default_pb_batch.size - assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size + assert len(batch) == default_pb_batch.size + assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch) assert batch.max_input_length == batch.input_lengths[0] @@ -97,7 +103,7 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch): assert isinstance(next_batch, CausalLMBatch) assert not next_batch.keys_head_dim_last - assert len(next_batch.all_input_ids) == next_batch.size + assert len(next_batch.all_input_ids) == len(next_batch) assert len(next_batch.all_input_ids[0]) == sequence_length + 1 assert len(next_batch.attention_mask[0]) == 11 assert torch.all(next_batch.all_input_ids[0][-2:] == 10264) @@ -106,7 +112,7 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch): assert torch.all(next_batch.attention_mask[0][:2] == 1) assert torch.all(next_batch.attention_mask[0][2:] == 0) - assert next_batch.input_ids.shape == (next_batch.size, 1) + assert next_batch.input_ids.shape == (len(next_batch), 1) assert next_batch.input_ids[0, 0] == 10264 assert next_batch.input_lengths == [2] @@ -170,6 +176,8 @@ def test_causal_lm_generate_token_completion_multi( == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens ) + next_batch = next_batch.filter([next_batch.requests[0]]) + for _ in range( default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens @@ -269,6 +277,8 @@ def test_batch_concatenate( == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens ) + next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]]) + for _ in range( default_bloom_batch.stopping_criterias[0].max_new_tokens - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens @@ -290,6 +300,8 @@ def test_batch_concatenate( == default_bloom_batch.stopping_criterias[0].max_new_tokens ) + next_batch = next_batch.filter([next_batch.requests[1]]) + for _ in range( default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens - default_bloom_batch.stopping_criterias[0].max_new_tokens diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index 7e860e07..ad79a4ca 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -68,7 +68,12 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch): assert batch.past_key_values is None - assert all([torch.equal(input_ids, all_input_ids[:, 0]) for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids)]) + assert all( + [ + torch.equal(input_ids, all_input_ids[:, 0]) + for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids) + ] + ) assert 
batch.input_lengths == [1] diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py index 79435787..79c9e936 100644 --- a/server/tests/models/test_seq2seq_lm.py +++ b/server/tests/models/test_seq2seq_lm.py @@ -49,8 +49,9 @@ def default_seq2seq_lm_batch(default_pb_batch, mt0_small_tokenizer): @pytest.fixture def default_multi_requests_seq2seq_lm_batch(default_pb_request, mt0_small_tokenizer): req_0 = copy(default_pb_request) + req_0.id = 1 req_1 = default_pb_request - req_1.id = 1 + req_1.id = 2 req_1.stopping_parameters.max_new_tokens = 5 batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2) @@ -72,7 +73,7 @@ def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch): assert torch.all(batch.attention_mask[0][-2:] == 1) assert torch.all(batch.attention_mask[0][:-2] == 0) - assert batch.decoder_input_ids.shape == (default_pb_batch.size, 1) + assert len(batch.decoder_input_ids) == default_pb_batch.size assert batch.decoder_attention_mask is None assert batch.encoder_last_hidden_state is None @@ -81,8 +82,8 @@ def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch): assert batch.input_lengths == [2] assert batch.decoder_input_lengths == [1] - assert batch.size == default_pb_batch.size - assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size + assert len(batch) == default_pb_batch.size + assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch) assert batch.max_input_length == batch.input_lengths[0] assert batch.max_decoder_input_length == batch.decoder_input_lengths[0] @@ -117,9 +118,9 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch) ) assert next_batch.stopping_criterias == default_seq2seq_lm_batch.stopping_criterias - assert next_batch.decoder_input_ids.shape == (next_batch.size, 2) - assert next_batch.decoder_input_ids[0, 0] == 0 - assert next_batch.decoder_input_ids[0, 1] == 259 + assert len(next_batch.decoder_input_ids) == len(next_batch) + assert next_batch.all_decoder_input_ids[0][0] == 0 + assert next_batch.all_decoder_input_ids[0][1] == 259 assert next_batch.decoder_attention_mask is None assert next_batch.encoder_last_hidden_state.shape == (1, sequence_length, 512) @@ -128,20 +129,20 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch) assert next_batch.past_key_values is not None assert all( - [p[0].shape == (next_batch.size, 6, 1, 64) for p in next_batch.past_key_values] + [p[0].shape == (len(next_batch), 6, 1, 64) for p in next_batch.past_key_values] ) assert all( - [p[1].shape == (next_batch.size, 6, 1, 64) for p in next_batch.past_key_values] + [p[1].shape == (len(next_batch), 6, 1, 64) for p in next_batch.past_key_values] ) assert all( [ - p[2].shape == (next_batch.size, 6, sequence_length, 64) + p[2].shape == (len(next_batch), 6, sequence_length, 64) for p in next_batch.past_key_values ] ) assert all( [ - p[3].shape == (next_batch.size, 6, sequence_length, 64) + p[3].shape == (len(next_batch), 6, sequence_length, 64) for p in next_batch.past_key_values ] ) @@ -189,6 +190,8 @@ def test_seq2seq_lm_generate_token_completion_multi( ) assert generations[1].generated_text.generated_tokens == 5 + next_batch = next_batch.filter([next_batch.requests[0]]) + generations, next_batch = default_seq2seq_lm.generate_token(next_batch) assert len(generations) == len(next_batch) @@ -223,7 +226,8 @@ def test_batch_concatenate( assert torch.equal( next_batch.decoder_input_ids[0], 
next_batch_0.decoder_input_ids[0] ) - assert torch.all(next_batch.decoder_input_ids[1:, 0] == 0) + assert next_batch.all_decoder_input_ids[1][0] == 0 + assert next_batch.all_decoder_input_ids[2][0] == 0 assert torch.equal( next_batch.decoder_input_ids[1:, -2:], next_batch_1.decoder_input_ids ) @@ -258,16 +262,16 @@ def test_batch_concatenate( assert next_batch.past_key_values is not None assert all( - [p[0].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values] + [p[0].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values] ) assert all( - [p[1].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values] + [p[1].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values] ) assert all( - [p[2].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values] + [p[2].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values] ) assert all( - [p[3].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values] + [p[3].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values] ) for i, past in enumerate(next_batch.past_key_values): @@ -306,6 +310,8 @@ def test_batch_concatenate( ) assert generations[2].generated_text.generated_tokens == 5 + next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]]) + generations, next_batch = default_seq2seq_lm.generate_token(next_batch) assert next_batch is not None @@ -314,6 +320,8 @@ def test_batch_concatenate( assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id assert generations[0].generated_text.generated_tokens == 7 + next_batch = next_batch.filter([next_batch.requests[1]]) + generations, next_batch = default_seq2seq_lm.generate_token(next_batch) assert next_batch is None diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 586334ed..71eff437 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -428,7 +428,7 @@ class CausalLM(Model): @tracer.start_as_current_span("generate_token") def generate_token( self, batch: CausalLMBatch - ) -> Tuple[List[Generation], CausalLMBatch]: + ) -> Tuple[List[Generation], Optional[CausalLMBatch]]: # slice the attention mask to the correct shape attention_mask = batch.attention_mask[:, : -batch.padding_right_offset] @@ -545,6 +545,10 @@ class CausalLM(Model): batch.token_offsets[i] = token_offset batch.max_input_length = max(batch.max_input_length, new_input_length) + # We finished all generations in the batch; there is no next batch + if stopped: + return generations, None + # Slice unused values from prefill batch.input_ids = batch.input_ids[:, :1] @@ -559,4 +563,4 @@ class CausalLM(Model): # Update past key values batch.past_key_values = past - return generations, batch if not stopped else None + return generations, batch diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index dc78aa8b..746e9e83 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -96,11 +96,13 @@ class GalacticaCausalLMBatch(CausalLMBatch): stopping_criterias = [] offsets = [] token_offsets = [] + requests_idx_mapping = {} # Parse batch max_truncation = 0 padding_right_offset = 0 - for r in pb.requests: + for i, r in enumerate(pb.requests): + requests_idx_mapping[r.id] = i # Add escape_custom_split_sequence to the CausalLMBatch 
logic inputs.append(escape_custom_split_sequence(r.inputs)) offsets.append(None) @@ -115,7 +117,6 @@ class GalacticaCausalLMBatch(CausalLMBatch): padding_right_offset, stopping_criteria.max_new_tokens ) - # Tokenize batch tokenized_inputs = tokenizer( inputs, return_tensors="pt", @@ -138,23 +139,23 @@ class GalacticaCausalLMBatch(CausalLMBatch): position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1 position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1) - all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1) + all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1) return cls( batch_id=pb.id, requests=pb.requests, + requests_idx_mapping=requests_idx_mapping, input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=None, - all_input_ids=all_input_ids, - input_lengths=input_lengths, + all_input_ids=list(all_input_ids), + input_lengths=input_lengths.tolist(), offsets=offsets, token_offsets=token_offsets, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, - size=pb.size, - max_input_length=max_input_length, + max_input_length=max_input_length.item(), padding_right_offset=padding_right_offset, ) diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 82b5def0..dd2f999b 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -3,7 +3,7 @@ import torch from dataclasses import dataclass from opentelemetry import trace from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase -from typing import Optional, Tuple, List, Type +from typing import Optional, Tuple, List, Type, Dict from text_generation_server.models import Model from text_generation_server.models.types import ( @@ -22,6 +22,7 @@ tracer = trace.get_tracer(__name__) class Seq2SeqLMBatch(Batch): batch_id: int requests: List[generate_pb2.Request] + requests_idx_mapping: Dict[int, int] # Encoder values input_ids: torch.Tensor @@ -32,6 +33,9 @@ class Seq2SeqLMBatch(Batch): decoder_attention_mask: Optional[torch.Tensor] encoder_last_hidden_state: Optional[torch.Tensor] + # All tokens + all_decoder_input_ids: List[torch.Tensor] + # Seq2SeqLM keeps track of both encoder and decoder attention keys and values past_key_values: Optional[List[Tuple]] @@ -46,7 +50,6 @@ class Seq2SeqLMBatch(Batch): stopping_criterias: List[StoppingCriteria] # Metadata used for padding - size: int max_input_length: int max_decoder_input_length: int padding_right_offset: int @@ -54,9 +57,7 @@ class Seq2SeqLMBatch(Batch): def to_pb(self) -> generate_pb2.Batch: """Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf""" return generate_pb2.Batch( - id=self.batch_id, - requests=self.requests, - size=self.size, + id=self.batch_id, requests=self.requests, size=len(self) ) @classmethod @@ -71,18 +72,17 @@ class Seq2SeqLMBatch(Batch): next_token_choosers = [] stopping_criterias = [] - decoder_input_ids = [] decoder_input_lengths = [] offsets = [] token_offsets = [] + requests_idx_mapping = {} # Parse batch max_truncation = 0 padding_right_offset = 0 - for r in pb.requests: + for i, r in enumerate(pb.requests): inputs.append(r.inputs) - # Decoder sequence only contains the bos_token - decoder_input_ids.append(tokenizer.bos_token_id) + requests_idx_mapping[r.id] = i decoder_input_lengths.append(1) offsets.append(None) token_offsets.append(None) @@ -109,15 +109,22 @@ class 
Seq2SeqLMBatch(Batch): input_lengths = tokenized_inputs["attention_mask"].sum(1) max_input_length = input_lengths.max() - # Convert decoder_input_ids to torch tensor of size [batch_size, 1] - decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1) + # Decoder sequence only contains the bos_token + decoder_input_ids = ( + torch.tensor(tokenizer.bos_token_id, device=device) + .repeat(len(pb.requests)) + .view(-1, 1) + ) + all_decoder_input_ids = decoder_input_ids.view(-1).split(1) return cls( batch_id=pb.id, requests=pb.requests, + requests_idx_mapping=requests_idx_mapping, input_ids=tokenized_inputs["input_ids"], attention_mask=tokenized_inputs["attention_mask"], decoder_input_ids=decoder_input_ids, + all_decoder_input_ids=list(all_decoder_input_ids), decoder_attention_mask=None, encoder_last_hidden_state=None, past_key_values=None, @@ -127,12 +134,96 @@ class Seq2SeqLMBatch(Batch): token_offsets=token_offsets, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, - size=len(pb.requests), max_input_length=max_input_length.item(), max_decoder_input_length=1, padding_right_offset=padding_right_offset, ) + @tracer.start_as_current_span("filter") + def filter( + self, requests: List[generate_pb2.Request] + ) -> Optional["Seq2SeqLMBatch"]: + if len(requests) == 0: + raise ValueError("Batch must have at least one request") + if len(requests) == len(self): + return self + + keep_indices = [] + + # New values after filtering + requests_idx_mapping = {} + input_lengths = [] + decoder_input_lengths = [] + offsets = [] + token_offsets = [] + + all_decoder_input_ids = [] + + next_token_choosers = [] + stopping_criterias = [] + + max_input_length = 0 + max_decoder_input_length = 0 + + for i, r in enumerate(requests): + idx = self.requests_idx_mapping[r.id] + requests_idx_mapping[r.id] = i + keep_indices.append(idx) + + offsets.append(self.offsets[idx]) + token_offsets.append(self.token_offsets[idx]) + + all_decoder_input_ids.append(self.all_decoder_input_ids[idx]) + + request_input_length = self.input_lengths[idx] + input_lengths.append(request_input_length) + max_input_length = max(max_input_length, request_input_length) + + request_decoder_input_length = self.decoder_input_lengths[idx] + decoder_input_lengths.append(request_decoder_input_length) + max_decoder_input_length = max( + max_decoder_input_length, request_decoder_input_length + ) + + next_token_choosers.append(self.next_token_choosers[idx]) + stopping_criterias.append(self.stopping_criterias[idx]) + + # Apply indices to input_ids, attention mask, past key values and other items that need to be cached + decoder_input_ids = self.decoder_input_ids[keep_indices] + attention_mask = self.attention_mask[keep_indices] + if self.decoder_attention_mask is not None: + decoder_attention_mask = self.decoder_attention_mask[keep_indices] + else: + decoder_attention_mask = None + + encoder_last_hidden_state = self.encoder_last_hidden_state[keep_indices] + + past_key_values = [ + [t[keep_indices] for t in layer] for layer in self.past_key_values + ] + + return Seq2SeqLMBatch( + batch_id=self.batch_id, + requests=requests, + requests_idx_mapping=requests_idx_mapping, + input_ids=None, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + all_decoder_input_ids=all_decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_last_hidden_state=encoder_last_hidden_state, + past_key_values=past_key_values, + input_lengths=input_lengths, + 
decoder_input_lengths=decoder_input_lengths, + offsets=offsets, + token_offsets=token_offsets, + next_token_choosers=next_token_choosers, + stopping_criterias=stopping_criterias, + max_input_length=max_input_length, + max_decoder_input_length=max_decoder_input_length, + padding_right_offset=self.padding_right_offset, + ) + @classmethod @tracer.start_as_current_span("concatenate") def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch": @@ -144,7 +235,7 @@ class Seq2SeqLMBatch(Batch): max_decoder_input_length = 0 padding_right_offset = 0 for batch in batches: - total_batch_size += batch.size + total_batch_size += len(batch) max_input_length = max(max_input_length, batch.max_input_length) max_decoder_input_length = max( max_decoder_input_length, batch.max_decoder_input_length @@ -153,6 +244,8 @@ class Seq2SeqLMBatch(Batch): # Batch attributes requests = [] + requests_idx_mapping = {} + all_decoder_input_ids = [] input_lengths = [] decoder_input_lengths = [] offsets = [] @@ -174,6 +267,7 @@ class Seq2SeqLMBatch(Batch): for i, batch in enumerate(batches): # Extend all list attributes requests.extend(batch.requests) + all_decoder_input_ids.extend(batch.all_decoder_input_ids) input_lengths.extend(batch.input_lengths) decoder_input_lengths.extend(batch.decoder_input_lengths) offsets.extend(batch.offsets) @@ -181,8 +275,15 @@ class Seq2SeqLMBatch(Batch): next_token_choosers.extend(batch.next_token_choosers) stopping_criterias.extend(batch.stopping_criterias) + if i == 0: + requests_idx_mapping = batch.requests_idx_mapping + else: + # We need to offset the mapping for each batch by the cumulative batch size + for k, v in batch.requests_idx_mapping.items(): + requests_idx_mapping[k] = v + start_index + # Slicing end index for this batch - end_index = start_index + batch.size + end_index = start_index + len(batch) # We only concatenate batches that did at least one step if batch.encoder_last_hidden_state is None: @@ -201,12 +302,10 @@ class Seq2SeqLMBatch(Batch): # Create padded tensor if decoder_input_ids is None: decoder_input_ids = batch.decoder_input_ids.new_zeros( - (total_batch_size, max_decoder_input_length), + (total_batch_size, 1), ) # Copy to correct indices - decoder_input_ids[ - start_index:end_index, -batch.max_decoder_input_length : - ] = batch.decoder_input_ids[:, -batch.max_decoder_input_length :] + decoder_input_ids[start_index:end_index] = batch.decoder_input_ids # Create padded tensor if decoder_attention_mask is None: @@ -302,14 +401,16 @@ class Seq2SeqLMBatch(Batch): start_index:end_index, :, -batch.max_input_length :, : ] = t[:, :, -batch.max_input_length :, :] - start_index += batch.size + start_index += len(batch) return cls( batch_id=batches[0].batch_id, requests=requests, + requests_idx_mapping=requests_idx_mapping, input_ids=None, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, + all_decoder_input_ids=all_decoder_input_ids, decoder_attention_mask=decoder_attention_mask, encoder_last_hidden_state=encoder_last_hidden_state, past_key_values=past_key_values, @@ -319,7 +420,6 @@ class Seq2SeqLMBatch(Batch): token_offsets=token_offsets, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, - size=total_batch_size, max_input_length=max_input_length, max_decoder_input_length=max_decoder_input_length, padding_right_offset=padding_right_offset, @@ -413,46 +513,25 @@ class Seq2SeqLM(Model): else: decoder_attention_mask = None - # check if first forward or not - if batch.past_key_values is not None: - # Only take the last 
token - decoder_input_ids = batch.decoder_input_ids[:, -1].unsqueeze(-1) - else: - decoder_input_ids = batch.decoder_input_ids - # Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]` # internally... if batch.encoder_last_hidden_state is not None: encoder_last_hidden_state = [batch.encoder_last_hidden_state] else: - encoder_last_hidden_state = batch.encoder_last_hidden_state + encoder_last_hidden_state = None logits, encoder_last_hidden_state, past = self.forward( batch.input_ids, batch.attention_mask, - decoder_input_ids, + batch.decoder_input_ids, decoder_attention_mask, encoder_last_hidden_state, batch.past_key_values, ) - # List of indices to cache - next_batch_keep_indices = [] - - # New values for next forward - next_batch_input_lengths = [] - next_batch_offsets = [] - next_batch_token_offsets = [] - next_batch_decoder_input_ids = [] - next_batch_decoder_input_lengths = [] - - # Metadata - next_batch_size = 0 - next_batch_max_input_length = 0 - next_batch_max_decoder_input_length = 0 - # Finished requests generations: List[Generation] = [] + stopped = True # Zipped iterator iterator = zip( @@ -464,7 +543,7 @@ class Seq2SeqLM(Model): logits, batch.next_token_choosers, batch.stopping_criterias, - batch.decoder_input_ids, + batch.all_decoder_input_ids, ) # For each member of the batch @@ -477,22 +556,24 @@ class Seq2SeqLM(Model): logits, next_token_chooser, stopping_criteria, - decoder_input_ids, + all_decoder_input_ids, ) in enumerate(iterator): # Select next token next_token_id, logprobs = next_token_chooser( - decoder_input_ids.view(1, -1), logits + all_decoder_input_ids.view(1, -1), logits ) # Append next token to decoder tokens - decoder_input_ids = torch.cat([decoder_input_ids, next_token_id.squeeze(1)]) + all_decoder_input_ids = torch.cat( + [all_decoder_input_ids, next_token_id.squeeze(1)] + ) new_decoder_input_length = decoder_input_length + 1 # Generated token next_token_logprob = logprobs[-1, next_token_id] next_token_id_squeezed = next_token_id.squeeze() next_token_text, offset, token_offset = self.decode_token( - decoder_input_ids, offset, token_offset + all_decoder_input_ids, offset, token_offset ) # Evaluate stopping criteria @@ -501,7 +582,7 @@ class Seq2SeqLM(Model): if stop: # Slice with decoder_input_length to remove padding # Decode all tokens - output_text = self.decode(decoder_input_ids[-decoder_input_length:]) + output_text = self.decode(all_decoder_input_ids[-decoder_input_length:]) # Get seed if isinstance(next_token_chooser.choice, Sampling): @@ -515,19 +596,7 @@ class Seq2SeqLM(Model): else: # Keep request in the batch generated_text = None - next_batch_keep_indices.append(i) - next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0)) - next_batch_size += 1 - next_batch_input_lengths.append(input_length) - next_batch_decoder_input_lengths.append(new_decoder_input_length) - next_batch_offsets.append(offset) - next_batch_token_offsets.append(token_offset) - next_batch_max_input_length = max( - next_batch_max_input_length, input_length - ) - next_batch_max_decoder_input_length = max( - next_batch_max_decoder_input_length, new_decoder_input_length - ) + stopped = False # Prefill if stopping_criteria.current_tokens == 1: @@ -551,69 +620,29 @@ class Seq2SeqLM(Model): generations.append(generation) + # Update values + batch.decoder_input_ids[i] = next_token_id + batch.all_decoder_input_ids[i] = all_decoder_input_ids + batch.input_lengths[i] = input_length + batch.decoder_input_lengths[i] = 
new_decoder_input_length + batch.offsets[i] = offset + batch.token_offsets[i] = token_offset + batch.max_input_length = max(batch.max_input_length, input_length) + batch.max_decoder_input_length = max( + batch.max_decoder_input_length, new_decoder_input_length + ) + # We finished all generations in the batch; there is no next batch - if not next_batch_keep_indices: + if stopped: return generations, None - next_batch_decoder_input_ids = torch.cat(next_batch_decoder_input_ids) - # If we finished at least one generation, we need to evict the indices of the generations that finished - # from the values of the next batch - if len(next_batch_keep_indices) != len(batch): - # Apply indices to decoder_attention mask, past key values and other items that need to be cached - next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices] - if batch.decoder_attention_mask is not None: - next_batch_decoder_attention_mask = batch.decoder_attention_mask[ - next_batch_keep_indices - ] - else: - next_batch_decoder_attention_mask = None - - next_batch_encoder_last_hidden_state = encoder_last_hidden_state[ - next_batch_keep_indices - ] - - next_batch_past_key_values = [ - [t[next_batch_keep_indices] for t in layer] for layer in past - ] - next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices] - next_batch_next_token_choosers = [ - batch.next_token_choosers[i] for i in next_batch_keep_indices - ] - next_batch_stopping_criterias = [ - batch.stopping_criterias[i] for i in next_batch_keep_indices - ] - else: - next_batch_attention_mask = batch.attention_mask - next_batch_decoder_attention_mask = batch.decoder_attention_mask - next_batch_encoder_last_hidden_state = encoder_last_hidden_state - next_batch_past_key_values = past - - next_batch_requests = batch.requests - next_batch_next_token_choosers = batch.next_token_choosers - next_batch_stopping_criterias = batch.stopping_criterias - + # We don't need input_ids after the prefill forward + batch.input_ids = None + batch.encoder_last_hidden_state = encoder_last_hidden_state + batch.past_key_values = past # Update decoder_attention_mask as we added a new token to input_ids - if next_batch_decoder_attention_mask is not None: - next_batch_decoder_attention_mask[:, -batch.padding_right_offset] = 1 + if batch.decoder_attention_mask is not None: + batch.decoder_attention_mask[:, -batch.padding_right_offset] = 1 + batch.padding_right_offset -= 1 - next_batch = Seq2SeqLMBatch( - batch_id=batch.batch_id, - requests=next_batch_requests, - input_ids=None, - attention_mask=next_batch_attention_mask, - decoder_input_ids=next_batch_decoder_input_ids, - decoder_attention_mask=next_batch_decoder_attention_mask, - encoder_last_hidden_state=next_batch_encoder_last_hidden_state, - past_key_values=next_batch_past_key_values, - input_lengths=next_batch_input_lengths, - decoder_input_lengths=next_batch_decoder_input_lengths, - offsets=next_batch_offsets, - token_offsets=next_batch_token_offsets, - next_token_choosers=next_batch_next_token_choosers, - stopping_criterias=next_batch_stopping_criterias, - size=next_batch_size, - max_input_length=next_batch_max_input_length, - max_decoder_input_length=next_batch_max_decoder_input_length, - padding_right_offset=batch.padding_right_offset - 1, - ) - return generations, next_batch + return generations, batch
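Editor's note: the server-side changes above replace the old "rebuild the next batch inside generate_token" logic with an explicit filter step driven by requests_idx_mapping (request id -> row index). Below is a condensed, hedged sketch of that pattern, not the actual CausalLMBatch/Seq2SeqLMBatch code; SimpleBatch and its fields are illustrative stand-ins.

# Illustrative sketch only: a batch keeps requests_idx_mapping (request id -> row
# index) and exposes filter(requests), which slices every per-request field down
# to the requests that are still running. SimpleBatch is a hypothetical stand-in.
from dataclasses import dataclass
from typing import Dict, List

import torch


@dataclass
class SimpleBatch:
    requests: List[int]                    # request ids, stand-in for generate_pb2.Request
    requests_idx_mapping: Dict[int, int]   # request id -> row index in the fields below
    input_lengths: List[int]
    all_input_ids: List[torch.Tensor]
    attention_mask: torch.Tensor           # shape [batch_size, seq_len]

    def __len__(self) -> int:
        return len(self.requests)

    def filter(self, requests: List[int]) -> "SimpleBatch":
        if len(requests) == 0:
            raise ValueError("Batch must have at least one request")
        if len(requests) == len(self):
            # Nothing finished: reuse the batch as-is
            return self

        keep_indices = []
        requests_idx_mapping: Dict[int, int] = {}
        input_lengths: List[int] = []
        all_input_ids: List[torch.Tensor] = []

        for i, request_id in enumerate(requests):
            idx = self.requests_idx_mapping[request_id]
            requests_idx_mapping[request_id] = i
            keep_indices.append(idx)
            input_lengths.append(self.input_lengths[idx])
            all_input_ids.append(self.all_input_ids[idx])

        # Tensors are sliced once with the kept row indices
        return SimpleBatch(
            requests=list(requests),
            requests_idx_mapping=requests_idx_mapping,
            input_lengths=input_lengths,
            all_input_ids=all_input_ids,
            attention_mask=self.attention_mask[keep_indices],
        )

Keeping the id-to-index mapping on the batch is what lets filter look up each surviving request in constant time and slice the cached tensors once with keep_indices, instead of recomputing the per-request bookkeeping on every decoding step as the removed next_batch_keep_indices code did.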
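For readers following the router changes at the top of the diff: State::next_batch now pops entries from the front of a VecDeque, skips entries whose response receiver was dropped by the client, stops once max_batch_size requests have been collected, and returns None if every candidate was already cancelled. The Python sketch below mirrors only that control flow; Entry, cancelled, and the return shape are simplified stand-ins for the Rust types, and the real implementation also checks the optional min_size and builds a protobuf Batch.

# Minimal Python rendition of the new queue-draining loop (conceptual only).
from collections import deque
from dataclasses import dataclass, field
from typing import Deque, Dict, Optional, Tuple


@dataclass
class Entry:
    payload: str
    cancelled: bool = False  # stand-in for entry.response_tx.is_disconnected()


@dataclass
class State:
    entries: Deque[Tuple[int, Entry]] = field(default_factory=deque)
    next_id: int = 0

    def append(self, entry: Entry) -> None:
        # Mirrors State::append: push to the back, hand out a monotonically increasing id
        self.entries.append((self.next_id, entry))
        self.next_id += 1

    def next_batch(self, max_batch_size: int) -> Optional[Dict[int, Entry]]:
        batch: Dict[int, Entry] = {}
        while self.entries:
            entry_id, entry = self.entries.popleft()
            if entry.cancelled:
                # Client dropped the request: discard it instead of batching it
                continue
            batch[entry_id] = entry
            if len(batch) == max_batch_size:
                # We have enough requests in the batch
                break
        # All remaining entries may have been cancelled
        return batch or None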