fix queue

OlivierDehaene 2023-04-19 16:24:18 +02:00
parent d9578153cb
commit 118f33d9dc
8 changed files with 280 additions and 183 deletions

View File

@ -3,6 +3,7 @@ use crate::infer::InferStreamResponse;
use crate::validation::ValidGenerateRequest;
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::min;
use std::collections::VecDeque;
use text_generation_client::{Batch, Request};
use tokio::sync::oneshot;
use tokio::time::Instant;
@ -102,7 +103,7 @@ async fn queue_task(receiver: flume::Receiver<QueueCommand>) {
#[derive(Debug)]
struct State {
/// Queue entries organized in a Vec
entries: Vec<(u64, Entry)>,
entries: VecDeque<(u64, Entry)>,
/// Id of the next entry
next_id: u64,
@ -114,7 +115,7 @@ struct State {
impl State {
fn new() -> Self {
Self {
entries: Vec::with_capacity(128),
entries: VecDeque::with_capacity(128),
next_id: 0,
next_batch_id: 0,
}
@ -127,7 +128,7 @@ impl State {
entry.temp_span = Some(queue_span);
// Push entry in the queue
self.entries.push((self.next_id, entry));
self.entries.push_back((self.next_id, entry));
self.next_id += 1;
metrics::increment_gauge!("tgi_queue_size", 1.0);
}
@ -155,8 +156,8 @@ impl State {
let mut batch_entries =
IntMap::with_capacity_and_hasher(max_batch_size, BuildNoHashHasher::default());
// Drain next_batch_size entries
for (id, mut entry) in self.entries.drain(..max_batch_size) {
// Iterate over the buffer
while let Some((id, mut entry)) = self.entries.pop_front() {
// Filter entries where the response receiver was dropped (== entries where the request
// was dropped by the client)
if entry.response_tx.is_disconnected() {
@ -182,6 +183,16 @@ impl State {
entry.batch_time = Some(Instant::now());
// Insert in batch_entries IntMap
batch_entries.insert(id, entry);
if batch_requests.len() == max_batch_size {
// We have enough requests in the batch
break;
}
}
// Maybe all entries were dropped because their channels were closed
if batch_requests.is_empty() {
return None;
}
// Final batch size once we dropped entries
@ -218,15 +229,16 @@ enum QueueCommand {
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
use tokio::sync::Semaphore;
use tracing::info_span;
fn default_entry() -> Entry {
let (response_tx, _) = flume::unbounded();
fn default_entry() -> (
Entry,
flume::Receiver<Result<InferStreamResponse, InferError>>,
) {
let (response_tx, receiver_tx) = flume::unbounded();
Entry {
let entry = Entry {
request: ValidGenerateRequest {
inputs: "".to_string(),
truncate: 0,
@ -251,13 +263,14 @@ mod tests {
temp_span: None,
queue_time: Instant::now(),
batch_time: None,
}
};
(entry, receiver_tx)
}
#[test]
fn test_append() {
let mut state = State::new();
let entry = default_entry();
let (entry, _guard) = default_entry();
assert_eq!(state.next_id, 0);
assert_eq!(state.entries.len(), 0);
@ -266,7 +279,7 @@ mod tests {
assert_eq!(state.next_id, 1);
assert_eq!(state.entries.len(), 1);
let (id, _) = state.entries.remove(0);
let (id, _) = state.entries.remove(0).unwrap();
assert_eq!(id, 0);
}
@ -281,8 +294,10 @@ mod tests {
#[test]
fn test_next_batch_min_size() {
let mut state = State::new();
state.append(default_entry());
state.append(default_entry());
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
state.append(entry2);
let (entries, batch, _) = state.next_batch(None, 2).unwrap();
assert_eq!(entries.len(), 2);
@ -297,21 +312,24 @@ mod tests {
assert_eq!(state.entries.len(), 0);
assert_eq!(state.next_batch_id, 1);
state.append(default_entry());
let (entry3, _guard3) = default_entry();
state.append(entry3);
assert!(state.next_batch(Some(2), 2).is_none());
assert_eq!(state.next_id, 3);
assert_eq!(state.entries.len(), 1);
let (id, _) = state.entries.remove(0);
let (id, _) = state.entries.remove(0).unwrap();
assert_eq!(id, 2);
}
#[test]
fn test_next_batch_max_size() {
let mut state = State::new();
state.append(default_entry());
state.append(default_entry());
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
state.append(entry2);
let (entries, batch, _) = state.next_batch(None, 1).unwrap();
assert_eq!(entries.len(), 1);
@ -323,7 +341,8 @@ mod tests {
assert_eq!(state.entries.len(), 1);
assert_eq!(state.next_batch_id, 1);
state.append(default_entry());
let (entry3, _guard3) = default_entry();
state.append(entry3);
let (entries, batch, _) = state.next_batch(None, 3).unwrap();
assert_eq!(entries.len(), 2);
@ -340,7 +359,8 @@ mod tests {
#[tokio::test]
async fn test_queue_append() {
let queue = Queue::new();
queue.append(default_entry());
let (entry, _guard) = default_entry();
queue.append(entry);
}
#[tokio::test]
@ -354,8 +374,10 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_min_size() {
let queue = Queue::new();
queue.append(default_entry());
queue.append(default_entry());
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
queue.append(entry2);
let (entries, batch, _) = queue.next_batch(None, 2).await.unwrap();
assert_eq!(entries.len(), 2);
@ -366,7 +388,8 @@ mod tests {
assert_eq!(batch.id, 0);
assert_eq!(batch.size, 2);
queue.append(default_entry());
let (entry3, _guard3) = default_entry();
queue.append(entry3);
assert!(queue.next_batch(Some(2), 2).await.is_none());
}
@ -374,8 +397,10 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_max_size() {
let queue = Queue::new();
queue.append(default_entry());
queue.append(default_entry());
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
queue.append(entry2);
let (entries, batch, _) = queue.next_batch(None, 1).await.unwrap();
assert_eq!(entries.len(), 1);
@ -383,7 +408,8 @@ mod tests {
assert_eq!(batch.id, 0);
assert_eq!(batch.size, 1);
queue.append(default_entry());
let (entry3, _guard3) = default_entry();
queue.append(entry3);
let (entries, batch, _) = queue.next_batch(None, 3).await.unwrap();
assert_eq!(entries.len(), 2);
@ -392,4 +418,13 @@ mod tests {
assert_eq!(batch.id, 1);
assert_eq!(batch.size, 2);
}
#[tokio::test]
async fn test_queue_next_batch_dropped_receiver() {
let queue = Queue::new();
let (entry, _) = default_entry();
queue.append(entry);
assert!(queue.next_batch(None, 1).await.is_none());
}
}
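
Note on the hunk above: the queue buffer becomes a VecDeque so next_batch can pop entries from the front, skip entries whose client already disconnected, and stop as soon as max_batch_size usable requests are collected, returning None when everything was dropped. A minimal sketch of that pattern, written here in Python with hypothetical Entry/next_batch names rather than the router's actual Rust types:

from collections import deque
from dataclasses import dataclass
from typing import Dict, Optional

@dataclass
class Entry:
    id: int
    disconnected: bool = False  # stands in for entry.response_tx.is_disconnected()

def next_batch(entries: "deque[Entry]", max_batch_size: int) -> Optional[Dict[int, Entry]]:
    # Pop from the front until the batch is full, skipping dropped clients
    batch: Dict[int, Entry] = {}
    while entries:
        entry = entries.popleft()
        if entry.disconnected:
            # The client went away; discard the entry instead of batching it
            continue
        batch[entry.id] = entry
        if len(batch) == max_batch_size:
            # We have enough requests in the batch
            break
    # Maybe every candidate was dropped because its channel was closed
    return batch or None

queue = deque(Entry(i) for i in range(4))
queue[1].disconnected = True
assert sorted(next_batch(queue, 2)) == [0, 2]
assert list(queue) == [Entry(3)]

Popping one entry at a time also means that dropped requests no longer count toward the batch size, which is what the new break-on-max_batch_size check above relies on.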

View File

@ -1,6 +1,9 @@
include Makefile-transformers
include Makefile-flash-att
unit-tests:
python -m pytest tests
gen-server:
# Compile protos
pip install grpcio-tools==1.51.1 mypy-protobuf==3.4.0 'types-protobuf>=3.20.4' --no-cache-dir
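
Assuming this hunk is the Python server's Makefile, the new unit-tests target is a thin alias: running `make unit-tests` presumably just invokes `python -m pytest tests` from the directory that holds the Makefile.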

View File

@ -45,8 +45,9 @@ def default_bloom_batch(default_pb_batch, bloom_560m_tokenizer):
@pytest.fixture
def default_multi_requests_bloom_batch(default_pb_request, bloom_560m_tokenizer):
req_0 = copy(default_pb_request)
req_0.id = 1
req_1 = default_pb_request
req_1.id = 1
req_1.id = 2
req_1.stopping_parameters.max_new_tokens = 5
batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
@ -70,12 +71,17 @@ def test_batch_from_pb(default_pb_batch, default_bloom_batch):
assert batch.past_key_values is None
assert torch.equal(batch.input_ids, batch.all_input_ids[:, :, 0])
assert all(
[
torch.equal(input_ids, all_input_ids[:, 0])
for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids)
]
)
assert batch.input_lengths == [1]
assert batch.size == default_pb_batch.size
assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
assert len(batch) == default_pb_batch.size
assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch)
assert batch.max_input_length == batch.input_lengths[0]
@ -97,7 +103,7 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
assert isinstance(next_batch, CausalLMBatch)
assert not next_batch.keys_head_dim_last
assert len(next_batch.all_input_ids) == next_batch.size
assert len(next_batch.all_input_ids) == len(next_batch)
assert len(next_batch.all_input_ids[0]) == sequence_length + 1
assert len(next_batch.attention_mask[0]) == 11
assert torch.all(next_batch.all_input_ids[0][-2:] == 10264)
@ -106,7 +112,7 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
assert torch.all(next_batch.attention_mask[0][:2] == 1)
assert torch.all(next_batch.attention_mask[0][2:] == 0)
assert next_batch.input_ids.shape == (next_batch.size, 1)
assert next_batch.input_ids.shape == (len(next_batch), 1)
assert next_batch.input_ids[0, 0] == 10264
assert next_batch.input_lengths == [2]
@ -170,6 +176,8 @@ def test_causal_lm_generate_token_completion_multi(
== default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
)
next_batch = next_batch.filter([next_batch.requests[0]])
for _ in range(
default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
- default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
@ -269,6 +277,8 @@ def test_batch_concatenate(
== default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
)
next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]])
for _ in range(
default_bloom_batch.stopping_criterias[0].max_new_tokens
- default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
@ -290,6 +300,8 @@ def test_batch_concatenate(
== default_bloom_batch.stopping_criterias[0].max_new_tokens
)
next_batch = next_batch.filter([next_batch.requests[1]])
for _ in range(
default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
- default_bloom_batch.stopping_criterias[0].max_new_tokens
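
Throughout these tests, the explicit batch.size field gives way to len(batch), and finished requests are pruned with next_batch.filter(...) between generation steps. A toy illustration of the len side of that change (ToyBatch is a hypothetical stand-in, not the real CausalLMBatch):

from dataclasses import dataclass
from typing import List

@dataclass
class ToyBatch:
    # The request list becomes the single source of truth for the batch size
    requests: List[int]

    def __len__(self) -> int:
        return len(self.requests)

batch = ToyBatch(requests=[0, 1])
assert len(batch) == 2   # replaces the old `batch.size == 2` style assertions
batch.requests.pop()     # e.g. after filter() removed a finished request
assert len(batch) == 1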

View File

@ -68,7 +68,12 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
assert batch.past_key_values is None
assert all([torch.equal(input_ids, all_input_ids[:, 0]) for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids)])
assert all(
[
torch.equal(input_ids, all_input_ids[:, 0])
for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids)
]
)
assert batch.input_lengths == [1]

View File

@ -49,8 +49,9 @@ def default_seq2seq_lm_batch(default_pb_batch, mt0_small_tokenizer):
@pytest.fixture
def default_multi_requests_seq2seq_lm_batch(default_pb_request, mt0_small_tokenizer):
req_0 = copy(default_pb_request)
req_0.id = 1
req_1 = default_pb_request
req_1.id = 1
req_1.id = 2
req_1.stopping_parameters.max_new_tokens = 5
batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
@ -72,7 +73,7 @@ def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch):
assert torch.all(batch.attention_mask[0][-2:] == 1)
assert torch.all(batch.attention_mask[0][:-2] == 0)
assert batch.decoder_input_ids.shape == (default_pb_batch.size, 1)
assert len(batch.decoder_input_ids) == default_pb_batch.size
assert batch.decoder_attention_mask is None
assert batch.encoder_last_hidden_state is None
@ -81,8 +82,8 @@ def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch):
assert batch.input_lengths == [2]
assert batch.decoder_input_lengths == [1]
assert batch.size == default_pb_batch.size
assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
assert len(batch) == default_pb_batch.size
assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch)
assert batch.max_input_length == batch.input_lengths[0]
assert batch.max_decoder_input_length == batch.decoder_input_lengths[0]
@ -117,9 +118,9 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
)
assert next_batch.stopping_criterias == default_seq2seq_lm_batch.stopping_criterias
assert next_batch.decoder_input_ids.shape == (next_batch.size, 2)
assert next_batch.decoder_input_ids[0, 0] == 0
assert next_batch.decoder_input_ids[0, 1] == 259
assert len(next_batch.decoder_input_ids) == len(next_batch)
assert next_batch.all_decoder_input_ids[0][0] == 0
assert next_batch.all_decoder_input_ids[0][1] == 259
assert next_batch.decoder_attention_mask is None
assert next_batch.encoder_last_hidden_state.shape == (1, sequence_length, 512)
@ -128,20 +129,20 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
assert next_batch.past_key_values is not None
assert all(
[p[0].shape == (next_batch.size, 6, 1, 64) for p in next_batch.past_key_values]
[p[0].shape == (len(next_batch), 6, 1, 64) for p in next_batch.past_key_values]
)
assert all(
[p[1].shape == (next_batch.size, 6, 1, 64) for p in next_batch.past_key_values]
[p[1].shape == (len(next_batch), 6, 1, 64) for p in next_batch.past_key_values]
)
assert all(
[
p[2].shape == (next_batch.size, 6, sequence_length, 64)
p[2].shape == (len(next_batch), 6, sequence_length, 64)
for p in next_batch.past_key_values
]
)
assert all(
[
p[3].shape == (next_batch.size, 6, sequence_length, 64)
p[3].shape == (len(next_batch), 6, sequence_length, 64)
for p in next_batch.past_key_values
]
)
@ -189,6 +190,8 @@ def test_seq2seq_lm_generate_token_completion_multi(
)
assert generations[1].generated_text.generated_tokens == 5
next_batch = next_batch.filter([next_batch.requests[0]])
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
@ -223,7 +226,8 @@ def test_batch_concatenate(
assert torch.equal(
next_batch.decoder_input_ids[0], next_batch_0.decoder_input_ids[0]
)
assert torch.all(next_batch.decoder_input_ids[1:, 0] == 0)
assert next_batch.all_decoder_input_ids[1][0] == 0
assert next_batch.all_decoder_input_ids[2][0] == 0
assert torch.equal(
next_batch.decoder_input_ids[1:, -2:], next_batch_1.decoder_input_ids
)
@ -258,16 +262,16 @@ def test_batch_concatenate(
assert next_batch.past_key_values is not None
assert all(
[p[0].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
[p[0].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values]
)
assert all(
[p[1].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
[p[1].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values]
)
assert all(
[p[2].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
[p[2].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values]
)
assert all(
[p[3].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
[p[3].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values]
)
for i, past in enumerate(next_batch.past_key_values):
@ -306,6 +310,8 @@ def test_batch_concatenate(
)
assert generations[2].generated_text.generated_tokens == 5
next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]])
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is not None
@ -314,6 +320,8 @@ def test_batch_concatenate(
assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
assert generations[0].generated_text.generated_tokens == 7
next_batch = next_batch.filter([next_batch.requests[1]])
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is None

View File

@ -428,7 +428,7 @@ class CausalLM(Model):
@tracer.start_as_current_span("generate_token")
def generate_token(
self, batch: CausalLMBatch
) -> Tuple[List[Generation], CausalLMBatch]:
) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
# slice the attention mask to the correct shape
attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
@ -545,6 +545,10 @@ class CausalLM(Model):
batch.token_offsets[i] = token_offset
batch.max_input_length = max(batch.max_input_length, new_input_length)
# We finished all generations in the batch; there is no next batch
if stopped:
return generations, None
# Slice unused values from prefill
batch.input_ids = batch.input_ids[:, :1]
@ -559,4 +563,4 @@ class CausalLM(Model):
# Update past key values
batch.past_key_values = past
return generations, batch if not stopped else None
return generations, batch
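
The signature change above makes the absence of a next batch explicit: generate_token now returns Optional[CausalLMBatch] and bails out with None as soon as every request has stopped, instead of computing `batch if not stopped else None` at the very end. A self-contained toy version of that control flow, with plain dicts standing in for the real batch and requests:

from typing import Dict, List, Optional, Tuple

def generate_token(batch: List[Dict[str, int]]) -> Tuple[List[str], Optional[List[Dict[str, int]]]]:
    generations: List[str] = []
    stopped = True  # stays True only if every request finished on this step
    for request in batch:
        request["tokens"] += 1
        generations.append(f"token for request {request['id']}")
        if request["tokens"] < request["max_tokens"]:
            stopped = False
    if stopped:
        # We finished all generations in the batch; there is no next batch
        return generations, None
    # Otherwise the mutated batch itself is the next batch
    return generations, batch

batch = [{"id": 0, "tokens": 0, "max_tokens": 2}]
_, next_batch = generate_token(batch)
assert next_batch is batch          # still running: same object, updated in place
_, next_batch = generate_token(batch)
assert next_batch is None           # finished: no next batch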

View File

@ -96,11 +96,13 @@ class GalacticaCausalLMBatch(CausalLMBatch):
stopping_criterias = []
offsets = []
token_offsets = []
requests_idx_mapping = {}
# Parse batch
max_truncation = 0
padding_right_offset = 0
for r in pb.requests:
for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
# Add escape_custom_split_sequence to the CausalLMBatch logic
inputs.append(escape_custom_split_sequence(r.inputs))
offsets.append(None)
@ -115,7 +117,6 @@ class GalacticaCausalLMBatch(CausalLMBatch):
padding_right_offset, stopping_criteria.max_new_tokens
)
# Tokenize batch
tokenized_inputs = tokenizer(
inputs,
return_tensors="pt",
@ -138,23 +139,23 @@ class GalacticaCausalLMBatch(CausalLMBatch):
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
return cls(
batch_id=pb.id,
requests=pb.requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=None,
all_input_ids=all_input_ids,
input_lengths=input_lengths,
all_input_ids=list(all_input_ids),
input_lengths=input_lengths.tolist(),
offsets=offsets,
token_offsets=token_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=pb.size,
max_input_length=max_input_length,
max_input_length=max_input_length.item(),
padding_right_offset=padding_right_offset,
)
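
Two things are added while parsing the protobuf requests above: a requests_idx_mapping from request id to position in the batch (used later by filter), and per-request all_input_ids columns obtained by transposing and splitting the tokenizer output. A small sketch with made-up request ids and token values:

import torch

# Toy stand-ins: two requests with ids 11 and 42 (the real objects are
# generate_pb2.Request protos with .id and .inputs fields)
request_ids = [11, 42]

requests_idx_mapping = {}
for i, request_id in enumerate(request_ids):
    # Remember where each request id lands in the batch so filter() can map
    # ids back to tensor rows later
    requests_idx_mapping[request_id] = i

# input_ids comes out of the tokenizer as [batch_size, seq_len]; transposing and
# splitting along dim=1 yields one [seq_len, 1] column per request
input_ids = torch.tensor([[5, 6, 7], [8, 9, 10]])
all_input_ids = input_ids.T.split(1, dim=1)

assert requests_idx_mapping == {11: 0, 42: 1}
assert all_input_ids[0].shape == (3, 1)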

View File

@ -3,7 +3,7 @@ import torch
from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type
from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.models import Model
from text_generation_server.models.types import (
@ -22,6 +22,7 @@ tracer = trace.get_tracer(__name__)
class Seq2SeqLMBatch(Batch):
batch_id: int
requests: List[generate_pb2.Request]
requests_idx_mapping: Dict[int, int]
# Encoder values
input_ids: torch.Tensor
@ -32,6 +33,9 @@ class Seq2SeqLMBatch(Batch):
decoder_attention_mask: Optional[torch.Tensor]
encoder_last_hidden_state: Optional[torch.Tensor]
# All tokens
all_decoder_input_ids: List[torch.Tensor]
# Seq2SeqLM keeps track of both encoder and decoder attention keys and values
past_key_values: Optional[List[Tuple]]
@ -46,7 +50,6 @@ class Seq2SeqLMBatch(Batch):
stopping_criterias: List[StoppingCriteria]
# Metadata used for padding
size: int
max_input_length: int
max_decoder_input_length: int
padding_right_offset: int
@ -54,9 +57,7 @@ class Seq2SeqLMBatch(Batch):
def to_pb(self) -> generate_pb2.Batch:
"""Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf"""
return generate_pb2.Batch(
id=self.batch_id,
requests=self.requests,
size=self.size,
id=self.batch_id, requests=self.requests, size=len(self)
)
@classmethod
@ -71,18 +72,17 @@ class Seq2SeqLMBatch(Batch):
next_token_choosers = []
stopping_criterias = []
decoder_input_ids = []
decoder_input_lengths = []
offsets = []
token_offsets = []
requests_idx_mapping = {}
# Parse batch
max_truncation = 0
padding_right_offset = 0
for r in pb.requests:
for i, r in enumerate(pb.requests):
inputs.append(r.inputs)
# Decoder sequence only contains the bos_token
decoder_input_ids.append(tokenizer.bos_token_id)
requests_idx_mapping[r.id] = i
decoder_input_lengths.append(1)
offsets.append(None)
token_offsets.append(None)
@ -109,15 +109,22 @@ class Seq2SeqLMBatch(Batch):
input_lengths = tokenized_inputs["attention_mask"].sum(1)
max_input_length = input_lengths.max()
# Convert decoder_input_ids to torch tensor of size [batch_size, 1]
decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)
# Decoder sequence only contains the bos_token
decoder_input_ids = (
torch.tensor(tokenizer.bos_token_id, device=device)
.repeat(len(pb.requests))
.view(-1, 1)
)
all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
return cls(
batch_id=pb.id,
requests=pb.requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=tokenized_inputs["input_ids"],
attention_mask=tokenized_inputs["attention_mask"],
decoder_input_ids=decoder_input_ids,
all_decoder_input_ids=list(all_decoder_input_ids),
decoder_attention_mask=None,
encoder_last_hidden_state=None,
past_key_values=None,
@ -127,12 +134,96 @@ class Seq2SeqLMBatch(Batch):
token_offsets=token_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=len(pb.requests),
max_input_length=max_input_length.item(),
max_decoder_input_length=1,
padding_right_offset=padding_right_offset,
)
@tracer.start_as_current_span("filter")
def filter(
self, requests: List[generate_pb2.Request]
) -> Optional["Seq2SeqLMBatch"]:
if len(requests) == 0:
raise ValueError("Batch must have at least one request")
if len(requests) == len(self):
return self
keep_indices = []
# New values after filtering
requests_idx_mapping = {}
input_lengths = []
decoder_input_lengths = []
offsets = []
token_offsets = []
all_decoder_input_ids = []
next_token_choosers = []
stopping_criterias = []
max_input_length = 0
max_decoder_input_length = 0
for i, r in enumerate(requests):
idx = self.requests_idx_mapping[r.id]
requests_idx_mapping[r.id] = i
keep_indices.append(idx)
offsets.append(self.offsets[idx])
token_offsets.append(self.token_offsets[idx])
all_decoder_input_ids.append(self.all_decoder_input_ids[idx])
request_input_length = self.input_lengths[idx]
input_lengths.append(request_input_length)
max_input_length = max(max_input_length, request_input_length)
request_decoder_input_length = self.decoder_input_lengths[idx]
decoder_input_lengths.append(request_decoder_input_length)
max_decoder_input_length = max(
max_decoder_input_length, request_decoder_input_length
)
next_token_choosers.append(self.next_token_choosers[idx])
stopping_criterias.append(self.stopping_criterias[idx])
# Apply indices to input_ids, attention mask, past key values and other items that need to be cached
decoder_input_ids = self.decoder_input_ids[keep_indices]
attention_mask = self.attention_mask[keep_indices]
if self.decoder_attention_mask is not None:
decoder_attention_mask = self.decoder_attention_mask[keep_indices]
else:
decoder_attention_mask = None
encoder_last_hidden_state = self.encoder_last_hidden_state[keep_indices]
past_key_values = [
[t[keep_indices] for t in layer] for layer in self.past_key_values
]
return Seq2SeqLMBatch(
batch_id=self.batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=None,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
all_decoder_input_ids=all_decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_last_hidden_state=encoder_last_hidden_state,
past_key_values=past_key_values,
input_lengths=input_lengths,
decoder_input_lengths=decoder_input_lengths,
offsets=offsets,
token_offsets=token_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
max_input_length=max_input_length,
max_decoder_input_length=max_decoder_input_length,
padding_right_offset=self.padding_right_offset,
)
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
@ -144,7 +235,7 @@ class Seq2SeqLMBatch(Batch):
max_decoder_input_length = 0
padding_right_offset = 0
for batch in batches:
total_batch_size += batch.size
total_batch_size += len(batch)
max_input_length = max(max_input_length, batch.max_input_length)
max_decoder_input_length = max(
max_decoder_input_length, batch.max_decoder_input_length
@ -153,6 +244,8 @@ class Seq2SeqLMBatch(Batch):
# Batch attributes
requests = []
requests_idx_mapping = {}
all_decoder_input_ids = []
input_lengths = []
decoder_input_lengths = []
offsets = []
@ -174,6 +267,7 @@ class Seq2SeqLMBatch(Batch):
for i, batch in enumerate(batches):
# Extend all list attributes
requests.extend(batch.requests)
all_decoder_input_ids.extend(batch.all_decoder_input_ids)
input_lengths.extend(batch.input_lengths)
decoder_input_lengths.extend(batch.decoder_input_lengths)
offsets.extend(batch.offsets)
@ -181,8 +275,15 @@ class Seq2SeqLMBatch(Batch):
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
if i == 0:
requests_idx_mapping = batch.requests_idx_mapping
else:
# We need to offset the mapping for each batch by the cumulative batch size
for k, v in batch.requests_idx_mapping.items():
requests_idx_mapping[k] = v + start_index
# Slicing end index for this batch
end_index = start_index + batch.size
end_index = start_index + len(batch)
# We only concatenate batches that did at least one step
if batch.encoder_last_hidden_state is None:
@ -201,12 +302,10 @@ class Seq2SeqLMBatch(Batch):
# Create padded tensor
if decoder_input_ids is None:
decoder_input_ids = batch.decoder_input_ids.new_zeros(
(total_batch_size, max_decoder_input_length),
(total_batch_size, 1),
)
# Copy to correct indices
decoder_input_ids[
start_index:end_index, -batch.max_decoder_input_length :
] = batch.decoder_input_ids[:, -batch.max_decoder_input_length :]
decoder_input_ids[start_index:end_index] = batch.decoder_input_ids
# Create padded tensor
if decoder_attention_mask is None:
@ -302,14 +401,16 @@ class Seq2SeqLMBatch(Batch):
start_index:end_index, :, -batch.max_input_length :, :
] = t[:, :, -batch.max_input_length :, :]
start_index += batch.size
start_index += len(batch)
return cls(
batch_id=batches[0].batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=None,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
all_decoder_input_ids=all_decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_last_hidden_state=encoder_last_hidden_state,
past_key_values=past_key_values,
@ -319,7 +420,6 @@ class Seq2SeqLMBatch(Batch):
token_offsets=token_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=total_batch_size,
max_input_length=max_input_length,
max_decoder_input_length=max_decoder_input_length,
padding_right_offset=padding_right_offset,
@ -413,46 +513,25 @@ class Seq2SeqLM(Model):
else:
decoder_attention_mask = None
# check if first forward or not
if batch.past_key_values is not None:
# Only take the last token
decoder_input_ids = batch.decoder_input_ids[:, -1].unsqueeze(-1)
else:
decoder_input_ids = batch.decoder_input_ids
# Wrap `encoder_last_hidden_state` because for some reason, Transformers does an `encoder_last_hidden_state[0]`
# internally...
if batch.encoder_last_hidden_state is not None:
encoder_last_hidden_state = [batch.encoder_last_hidden_state]
else:
encoder_last_hidden_state = batch.encoder_last_hidden_state
encoder_last_hidden_state = None
logits, encoder_last_hidden_state, past = self.forward(
batch.input_ids,
batch.attention_mask,
decoder_input_ids,
batch.decoder_input_ids,
decoder_attention_mask,
encoder_last_hidden_state,
batch.past_key_values,
)
# List of indices to cache
next_batch_keep_indices = []
# New values for next forward
next_batch_input_lengths = []
next_batch_offsets = []
next_batch_token_offsets = []
next_batch_decoder_input_ids = []
next_batch_decoder_input_lengths = []
# Metadata
next_batch_size = 0
next_batch_max_input_length = 0
next_batch_max_decoder_input_length = 0
# Finished requests
generations: List[Generation] = []
stopped = True
# Zipped iterator
iterator = zip(
@ -464,7 +543,7 @@ class Seq2SeqLM(Model):
logits,
batch.next_token_choosers,
batch.stopping_criterias,
batch.decoder_input_ids,
batch.all_decoder_input_ids,
)
# For each member of the batch
@ -477,22 +556,24 @@ class Seq2SeqLM(Model):
logits,
next_token_chooser,
stopping_criteria,
decoder_input_ids,
all_decoder_input_ids,
) in enumerate(iterator):
# Select next token
next_token_id, logprobs = next_token_chooser(
decoder_input_ids.view(1, -1), logits
all_decoder_input_ids.view(1, -1), logits
)
# Append next token to decoder tokens
decoder_input_ids = torch.cat([decoder_input_ids, next_token_id.squeeze(1)])
all_decoder_input_ids = torch.cat(
[all_decoder_input_ids, next_token_id.squeeze(1)]
)
new_decoder_input_length = decoder_input_length + 1
# Generated token
next_token_logprob = logprobs[-1, next_token_id]
next_token_id_squeezed = next_token_id.squeeze()
next_token_text, offset, token_offset = self.decode_token(
decoder_input_ids, offset, token_offset
all_decoder_input_ids, offset, token_offset
)
# Evaluate stopping criteria
@ -501,7 +582,7 @@ class Seq2SeqLM(Model):
if stop:
# Slice with decoder_input_length to remove padding
# Decode all tokens
output_text = self.decode(decoder_input_ids[-decoder_input_length:])
output_text = self.decode(all_decoder_input_ids[-decoder_input_length:])
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
@ -515,19 +596,7 @@ class Seq2SeqLM(Model):
else:
# Keep request in the batch
generated_text = None
next_batch_keep_indices.append(i)
next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0))
next_batch_size += 1
next_batch_input_lengths.append(input_length)
next_batch_decoder_input_lengths.append(new_decoder_input_length)
next_batch_offsets.append(offset)
next_batch_token_offsets.append(token_offset)
next_batch_max_input_length = max(
next_batch_max_input_length, input_length
)
next_batch_max_decoder_input_length = max(
next_batch_max_decoder_input_length, new_decoder_input_length
)
stopped = False
# Prefill
if stopping_criteria.current_tokens == 1:
@ -551,69 +620,29 @@ class Seq2SeqLM(Model):
generations.append(generation)
# Update values
batch.decoder_input_ids[i] = next_token_id
batch.all_decoder_input_ids[i] = all_decoder_input_ids
batch.input_lengths[i] = input_length
batch.decoder_input_lengths[i] = new_decoder_input_length
batch.offsets[i] = offset
batch.token_offsets[i] = token_offset
batch.max_input_length = max(batch.max_input_length, input_length)
batch.max_decoder_input_length = max(
batch.max_decoder_input_length, new_decoder_input_length
)
# We finished all generations in the batch; there is no next batch
if not next_batch_keep_indices:
if stopped:
return generations, None
next_batch_decoder_input_ids = torch.cat(next_batch_decoder_input_ids)
# If we finished at least one generation, we need to evict the indices of the generations that finished
# from the values of the next batch
if len(next_batch_keep_indices) != len(batch):
# Apply indices to decoder_attention mask, past key values and other items that need to be cached
next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
if batch.decoder_attention_mask is not None:
next_batch_decoder_attention_mask = batch.decoder_attention_mask[
next_batch_keep_indices
]
else:
next_batch_decoder_attention_mask = None
next_batch_encoder_last_hidden_state = encoder_last_hidden_state[
next_batch_keep_indices
]
next_batch_past_key_values = [
[t[next_batch_keep_indices] for t in layer] for layer in past
]
next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
next_batch_next_token_choosers = [
batch.next_token_choosers[i] for i in next_batch_keep_indices
]
next_batch_stopping_criterias = [
batch.stopping_criterias[i] for i in next_batch_keep_indices
]
else:
next_batch_attention_mask = batch.attention_mask
next_batch_decoder_attention_mask = batch.decoder_attention_mask
next_batch_encoder_last_hidden_state = encoder_last_hidden_state
next_batch_past_key_values = past
next_batch_requests = batch.requests
next_batch_next_token_choosers = batch.next_token_choosers
next_batch_stopping_criterias = batch.stopping_criterias
# We don't need input_ids after the prefill forward
batch.input_ids = None
batch.encoder_last_hidden_state = encoder_last_hidden_state
batch.past_key_values = past
# Update decoder_attention_mask as we added a new token to input_ids
if next_batch_decoder_attention_mask is not None:
next_batch_decoder_attention_mask[:, -batch.padding_right_offset] = 1
if batch.decoder_attention_mask is not None:
batch.decoder_attention_mask[:, -batch.padding_right_offset] = 1
batch.padding_right_offset -= 1
next_batch = Seq2SeqLMBatch(
batch_id=batch.batch_id,
requests=next_batch_requests,
input_ids=None,
attention_mask=next_batch_attention_mask,
decoder_input_ids=next_batch_decoder_input_ids,
decoder_attention_mask=next_batch_decoder_attention_mask,
encoder_last_hidden_state=next_batch_encoder_last_hidden_state,
past_key_values=next_batch_past_key_values,
input_lengths=next_batch_input_lengths,
decoder_input_lengths=next_batch_decoder_input_lengths,
offsets=next_batch_offsets,
token_offsets=next_batch_token_offsets,
next_token_choosers=next_batch_next_token_choosers,
stopping_criterias=next_batch_stopping_criterias,
size=next_batch_size,
max_input_length=next_batch_max_input_length,
max_decoder_input_length=next_batch_max_decoder_input_length,
padding_right_offset=batch.padding_right_offset - 1,
)
return generations, next_batch
return generations, batch
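
The bulk of this file is the new filter/concatenate bookkeeping: generate_token now mutates the batch in place, callers shrink it by request id via filter, and concatenate offsets each incoming requests_idx_mapping by the rows already consumed. A compressed sketch of that pattern on a toy batch holding a single tensor, not the real Seq2SeqLMBatch dataclass:

import torch
from dataclasses import dataclass
from typing import Dict, List

@dataclass
class ToyBatch:
    request_ids: List[int]
    requests_idx_mapping: Dict[int, int]   # request id -> row index
    decoder_input_ids: torch.Tensor        # [batch_size, 1]

    def filter(self, request_ids: List[int]) -> "ToyBatch":
        # Translate the surviving ids into row indices, then index every
        # batched tensor once with those indices
        keep_indices = [self.requests_idx_mapping[rid] for rid in request_ids]
        requests_idx_mapping = {rid: i for i, rid in enumerate(request_ids)}
        return ToyBatch(
            request_ids=request_ids,
            requests_idx_mapping=requests_idx_mapping,
            decoder_input_ids=self.decoder_input_ids[keep_indices],
        )

    @classmethod
    def concatenate(cls, batches: List["ToyBatch"]) -> "ToyBatch":
        request_ids, mapping, tensors = [], {}, []
        start_index = 0
        for batch in batches:
            request_ids.extend(batch.request_ids)
            # Offset each batch's mapping by the rows already consumed
            for rid, idx in batch.requests_idx_mapping.items():
                mapping[rid] = idx + start_index
            tensors.append(batch.decoder_input_ids)
            start_index += len(batch.request_ids)
        return cls(request_ids, mapping, torch.cat(tensors))

a = ToyBatch([1, 2], {1: 0, 2: 1}, torch.tensor([[0], [0]]))
b = ToyBatch([3], {3: 0}, torch.tensor([[0]]))
merged = ToyBatch.concatenate([a, b])
assert merged.requests_idx_mapping == {1: 0, 2: 1, 3: 2}
kept = merged.filter([3])
assert kept.decoder_input_ids.shape == (1, 1)

Indexing every cached tensor once with keep_indices is what replaces the old next_batch_keep_indices rebuilding that used to live inside generate_token.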