Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 11:24:53 +00:00)

Commit 118f33d9dc (parent d9578153cb): fix queue
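The central change is in the router queue: the backing store moves from a Vec that was drained in one go to a VecDeque popped from the front until the batch is full, so entries whose clients already disconnected no longer consume batch slots. A minimal sketch of that drain pattern (simplified, with placeholder types; not the exact router code):

use std::collections::VecDeque;

// Placeholder entry type; the real router stores per-request state here.
struct Entry;

// Pop entries off the front until the batch is full, instead of draining a
// fixed-size slice up front.
fn next_batch(entries: &mut VecDeque<(u64, Entry)>, max_batch_size: usize) -> Vec<(u64, Entry)> {
    let mut batch = Vec::with_capacity(max_batch_size);
    while let Some((id, entry)) = entries.pop_front() {
        // The real implementation also skips entries whose response channel
        // was dropped by the client before counting them toward the batch.
        batch.push((id, entry));
        if batch.len() == max_batch_size {
            break;
        }
    }
    batch
}

fn main() {
    let mut queue: VecDeque<(u64, Entry)> = (0..5).map(|id| (id, Entry)).collect();
    let batch = next_batch(&mut queue, 2);
    assert_eq!(batch.len(), 2);
    assert_eq!(queue.len(), 3);
}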
@@ -3,6 +3,7 @@ use crate::infer::InferStreamResponse;
 use crate::validation::ValidGenerateRequest;
 use nohash_hasher::{BuildNoHashHasher, IntMap};
 use std::cmp::min;
+use std::collections::VecDeque;
 use text_generation_client::{Batch, Request};
 use tokio::sync::oneshot;
 use tokio::time::Instant;
@@ -102,7 +103,7 @@ async fn queue_task(receiver: flume::Receiver<QueueCommand>) {
 #[derive(Debug)]
 struct State {
     /// Queue entries organized in a Vec
-    entries: Vec<(u64, Entry)>,
+    entries: VecDeque<(u64, Entry)>,

     /// Id of the next entry
     next_id: u64,
@@ -114,7 +115,7 @@ struct State {
 impl State {
     fn new() -> Self {
         Self {
-            entries: Vec::with_capacity(128),
+            entries: VecDeque::with_capacity(128),
             next_id: 0,
             next_batch_id: 0,
         }
@@ -127,7 +128,7 @@ impl State {
         entry.temp_span = Some(queue_span);

         // Push entry in the queue
-        self.entries.push((self.next_id, entry));
+        self.entries.push_back((self.next_id, entry));
         self.next_id += 1;
         metrics::increment_gauge!("tgi_queue_size", 1.0);
     }
@@ -155,8 +156,8 @@ impl State {
         let mut batch_entries =
             IntMap::with_capacity_and_hasher(max_batch_size, BuildNoHashHasher::default());

-        // Drain next_batch_size entries
-        for (id, mut entry) in self.entries.drain(..max_batch_size) {
+        // Iterate on buffer
+        while let Some((id, mut entry)) = self.entries.pop_front() {
             // Filter entries where the response receiver was dropped (== entries where the request
             // was dropped by the client)
             if entry.response_tx.is_disconnected() {
@@ -182,6 +183,16 @@ impl State {
             entry.batch_time = Some(Instant::now());
             // Insert in batch_entries IntMap
             batch_entries.insert(id, entry);
+
+            if batch_requests.len() == max_batch_size {
+                // We have enough requests in the batch
+                break;
+            }
         }
+
+        // Maybe all entries were dropped because their channel were closed
+        if batch_requests.is_empty() {
+            return None;
+        }

         // Final batch size once we dropped entries
@@ -218,15 +229,16 @@ enum QueueCommand {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use std::sync::Arc;
     use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
-    use tokio::sync::Semaphore;
     use tracing::info_span;

-    fn default_entry() -> Entry {
-        let (response_tx, _) = flume::unbounded();
+    fn default_entry() -> (
+        Entry,
+        flume::Receiver<Result<InferStreamResponse, InferError>>,
+    ) {
+        let (response_tx, receiver_tx) = flume::unbounded();

-        Entry {
+        let entry = Entry {
             request: ValidGenerateRequest {
                 inputs: "".to_string(),
                 truncate: 0,
@@ -251,13 +263,14 @@ mod tests {
             temp_span: None,
             queue_time: Instant::now(),
             batch_time: None,
-        }
+        };
+        (entry, receiver_tx)
     }

     #[test]
     fn test_append() {
         let mut state = State::new();
-        let entry = default_entry();
+        let (entry, _guard) = default_entry();

         assert_eq!(state.next_id, 0);
         assert_eq!(state.entries.len(), 0);
@@ -266,7 +279,7 @@ mod tests {

         assert_eq!(state.next_id, 1);
         assert_eq!(state.entries.len(), 1);
-        let (id, _) = state.entries.remove(0);
+        let (id, _) = state.entries.remove(0).unwrap();
         assert_eq!(id, 0);
     }

@@ -281,8 +294,10 @@ mod tests {
     #[test]
     fn test_next_batch_min_size() {
         let mut state = State::new();
-        state.append(default_entry());
-        state.append(default_entry());
+        let (entry1, _guard1) = default_entry();
+        let (entry2, _guard2) = default_entry();
+        state.append(entry1);
+        state.append(entry2);

         let (entries, batch, _) = state.next_batch(None, 2).unwrap();
         assert_eq!(entries.len(), 2);
@@ -297,21 +312,24 @@ mod tests {
         assert_eq!(state.entries.len(), 0);
         assert_eq!(state.next_batch_id, 1);

-        state.append(default_entry());
+        let (entry3, _guard3) = default_entry();
+        state.append(entry3);

         assert!(state.next_batch(Some(2), 2).is_none());

         assert_eq!(state.next_id, 3);
         assert_eq!(state.entries.len(), 1);
-        let (id, _) = state.entries.remove(0);
+        let (id, _) = state.entries.remove(0).unwrap();
         assert_eq!(id, 2);
     }

     #[test]
     fn test_next_batch_max_size() {
         let mut state = State::new();
-        state.append(default_entry());
-        state.append(default_entry());
+        let (entry1, _guard1) = default_entry();
+        let (entry2, _guard2) = default_entry();
+        state.append(entry1);
+        state.append(entry2);

         let (entries, batch, _) = state.next_batch(None, 1).unwrap();
         assert_eq!(entries.len(), 1);
@@ -323,7 +341,8 @@ mod tests {
         assert_eq!(state.entries.len(), 1);
         assert_eq!(state.next_batch_id, 1);

-        state.append(default_entry());
+        let (entry3, _guard3) = default_entry();
+        state.append(entry3);

         let (entries, batch, _) = state.next_batch(None, 3).unwrap();
         assert_eq!(entries.len(), 2);
@@ -340,7 +359,8 @@ mod tests {
     #[tokio::test]
     async fn test_queue_append() {
         let queue = Queue::new();
-        queue.append(default_entry());
+        let (entry, _guard) = default_entry();
+        queue.append(entry);
     }

     #[tokio::test]
@@ -354,8 +374,10 @@ mod tests {
     #[tokio::test]
     async fn test_queue_next_batch_min_size() {
         let queue = Queue::new();
-        queue.append(default_entry());
-        queue.append(default_entry());
+        let (entry1, _guard1) = default_entry();
+        let (entry2, _guard2) = default_entry();
+        queue.append(entry1);
+        queue.append(entry2);

         let (entries, batch, _) = queue.next_batch(None, 2).await.unwrap();
         assert_eq!(entries.len(), 2);
@@ -366,7 +388,8 @@ mod tests {
         assert_eq!(batch.id, 0);
         assert_eq!(batch.size, 2);

-        queue.append(default_entry());
+        let (entry3, _guard3) = default_entry();
+        queue.append(entry3);

         assert!(queue.next_batch(Some(2), 2).await.is_none());
     }
@@ -374,8 +397,10 @@ mod tests {
     #[tokio::test]
     async fn test_queue_next_batch_max_size() {
         let queue = Queue::new();
-        queue.append(default_entry());
-        queue.append(default_entry());
+        let (entry1, _guard1) = default_entry();
+        let (entry2, _guard2) = default_entry();
+        queue.append(entry1);
+        queue.append(entry2);

         let (entries, batch, _) = queue.next_batch(None, 1).await.unwrap();
         assert_eq!(entries.len(), 1);
@@ -383,7 +408,8 @@ mod tests {
         assert_eq!(batch.id, 0);
         assert_eq!(batch.size, 1);

-        queue.append(default_entry());
+        let (entry3, _guard3) = default_entry();
+        queue.append(entry3);

         let (entries, batch, _) = queue.next_batch(None, 3).await.unwrap();
         assert_eq!(entries.len(), 2);
@@ -392,4 +418,13 @@ mod tests {
         assert_eq!(batch.id, 1);
         assert_eq!(batch.size, 2);
     }
+
+    #[tokio::test]
+    async fn test_queue_next_batch_dropped_receiver() {
+        let queue = Queue::new();
+        let (entry, _) = default_entry();
+        queue.append(entry);
+
+        assert!(queue.next_batch(None, 1).await.is_none());
+    }
 }
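In the updated tests, default_entry returns the receiving side of the flume channel along with the Entry, and callers bind it to a _guard variable. Keeping that receiver alive is what stops response_tx.is_disconnected() from filtering the entry out of the next batch; the new test_queue_next_batch_dropped_receiver covers the opposite case. A small sketch of the flume behaviour this relies on (illustrative only):

fn main() {
    let (tx, rx) = flume::unbounded::<u32>();

    // While a receiver exists, the sender is still connected and the queue
    // would keep the entry.
    assert!(!tx.is_disconnected());

    // Once every receiver is dropped, the sender reports disconnected and the
    // queue skips the entry, which is what the dropped-receiver test asserts.
    drop(rx);
    assert!(tx.is_disconnected());
}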
@@ -1,6 +1,9 @@
 include Makefile-transformers
 include Makefile-flash-att

+unit-tests:
+	python -m pytest tests
+
 gen-server:
 	# Compile protos
 	pip install grpcio-tools==1.51.1 mypy-protobuf==3.4.0 'types-protobuf>=3.20.4' --no-cache-dir
@@ -45,8 +45,9 @@ def default_bloom_batch(default_pb_batch, bloom_560m_tokenizer):
 @pytest.fixture
 def default_multi_requests_bloom_batch(default_pb_request, bloom_560m_tokenizer):
     req_0 = copy(default_pb_request)
+    req_0.id = 1
     req_1 = default_pb_request
-    req_1.id = 1
+    req_1.id = 2
     req_1.stopping_parameters.max_new_tokens = 5

     batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
@@ -70,12 +71,17 @@ def test_batch_from_pb(default_pb_batch, default_bloom_batch):

     assert batch.past_key_values is None

-    assert torch.equal(batch.input_ids, batch.all_input_ids[:, :, 0])
+    assert all(
+        [
+            torch.equal(input_ids, all_input_ids[:, 0])
+            for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids)
+        ]
+    )

     assert batch.input_lengths == [1]

-    assert batch.size == default_pb_batch.size
-    assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
+    assert len(batch) == default_pb_batch.size
+    assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch)

     assert batch.max_input_length == batch.input_lengths[0]

@@ -97,7 +103,7 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
     assert isinstance(next_batch, CausalLMBatch)
     assert not next_batch.keys_head_dim_last

-    assert len(next_batch.all_input_ids) == next_batch.size
+    assert len(next_batch.all_input_ids) == len(next_batch)
     assert len(next_batch.all_input_ids[0]) == sequence_length + 1
     assert len(next_batch.attention_mask[0]) == 11
     assert torch.all(next_batch.all_input_ids[0][-2:] == 10264)
@@ -106,7 +112,7 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
     assert torch.all(next_batch.attention_mask[0][:2] == 1)
     assert torch.all(next_batch.attention_mask[0][2:] == 0)

-    assert next_batch.input_ids.shape == (next_batch.size, 1)
+    assert next_batch.input_ids.shape == (len(next_batch), 1)
     assert next_batch.input_ids[0, 0] == 10264

     assert next_batch.input_lengths == [2]
@@ -170,6 +176,8 @@ def test_causal_lm_generate_token_completion_multi(
         == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
     )

+    next_batch = next_batch.filter([next_batch.requests[0]])
+
     for _ in range(
         default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
         - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
@@ -269,6 +277,8 @@ def test_batch_concatenate(
         == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
     )

+    next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]])
+
     for _ in range(
         default_bloom_batch.stopping_criterias[0].max_new_tokens
         - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
@@ -290,6 +300,8 @@ def test_batch_concatenate(
         == default_bloom_batch.stopping_criterias[0].max_new_tokens
     )

+    next_batch = next_batch.filter([next_batch.requests[1]])
+
     for _ in range(
         default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
         - default_bloom_batch.stopping_criterias[0].max_new_tokens
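The assertions above switch from batch.size to len(batch), which assumes the batch classes expose __len__ (consistent with the size field being dropped from Seq2SeqLMBatch later in this commit). A minimal sketch of that protocol with a hypothetical stand-in class:

from dataclasses import dataclass
from typing import List


@dataclass
class DemoBatch:
    # Hypothetical stand-in for CausalLMBatch / Seq2SeqLMBatch: the batch size
    # is derived from the request list instead of being stored separately.
    requests: List[int]

    def __len__(self) -> int:
        return len(self.requests)


batch = DemoBatch(requests=[10, 11, 12])
assert len(batch) == 3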
@@ -68,7 +68,12 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):

     assert batch.past_key_values is None

-    assert all([torch.equal(input_ids, all_input_ids[:, 0]) for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids)])
+    assert all(
+        [
+            torch.equal(input_ids, all_input_ids[:, 0])
+            for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids)
+        ]
+    )

     assert batch.input_lengths == [1]

@@ -49,8 +49,9 @@ def default_seq2seq_lm_batch(default_pb_batch, mt0_small_tokenizer):
 @pytest.fixture
 def default_multi_requests_seq2seq_lm_batch(default_pb_request, mt0_small_tokenizer):
     req_0 = copy(default_pb_request)
+    req_0.id = 1
     req_1 = default_pb_request
-    req_1.id = 1
+    req_1.id = 2
     req_1.stopping_parameters.max_new_tokens = 5

     batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
@@ -72,7 +73,7 @@ def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch):
     assert torch.all(batch.attention_mask[0][-2:] == 1)
     assert torch.all(batch.attention_mask[0][:-2] == 0)

-    assert batch.decoder_input_ids.shape == (default_pb_batch.size, 1)
+    assert len(batch.decoder_input_ids) == default_pb_batch.size
     assert batch.decoder_attention_mask is None
     assert batch.encoder_last_hidden_state is None

@@ -81,8 +82,8 @@ def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch):
     assert batch.input_lengths == [2]
     assert batch.decoder_input_lengths == [1]

-    assert batch.size == default_pb_batch.size
-    assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
+    assert len(batch) == default_pb_batch.size
+    assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch)

     assert batch.max_input_length == batch.input_lengths[0]
     assert batch.max_decoder_input_length == batch.decoder_input_lengths[0]
@@ -117,9 +118,9 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
     )
     assert next_batch.stopping_criterias == default_seq2seq_lm_batch.stopping_criterias

-    assert next_batch.decoder_input_ids.shape == (next_batch.size, 2)
-    assert next_batch.decoder_input_ids[0, 0] == 0
-    assert next_batch.decoder_input_ids[0, 1] == 259
+    assert len(next_batch.decoder_input_ids) == len(next_batch)
+    assert next_batch.all_decoder_input_ids[0][0] == 0
+    assert next_batch.all_decoder_input_ids[0][1] == 259
     assert next_batch.decoder_attention_mask is None
     assert next_batch.encoder_last_hidden_state.shape == (1, sequence_length, 512)

@@ -128,20 +129,20 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)

     assert next_batch.past_key_values is not None
     assert all(
-        [p[0].shape == (next_batch.size, 6, 1, 64) for p in next_batch.past_key_values]
+        [p[0].shape == (len(next_batch), 6, 1, 64) for p in next_batch.past_key_values]
     )
     assert all(
-        [p[1].shape == (next_batch.size, 6, 1, 64) for p in next_batch.past_key_values]
+        [p[1].shape == (len(next_batch), 6, 1, 64) for p in next_batch.past_key_values]
     )
     assert all(
         [
-            p[2].shape == (next_batch.size, 6, sequence_length, 64)
+            p[2].shape == (len(next_batch), 6, sequence_length, 64)
             for p in next_batch.past_key_values
         ]
     )
     assert all(
         [
-            p[3].shape == (next_batch.size, 6, sequence_length, 64)
+            p[3].shape == (len(next_batch), 6, sequence_length, 64)
             for p in next_batch.past_key_values
         ]
     )
@@ -189,6 +190,8 @@ def test_seq2seq_lm_generate_token_completion_multi(
     )
     assert generations[1].generated_text.generated_tokens == 5

+    next_batch = next_batch.filter([next_batch.requests[0]])
+
     generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert len(generations) == len(next_batch)

@@ -223,7 +226,8 @@ def test_batch_concatenate(
     assert torch.equal(
         next_batch.decoder_input_ids[0], next_batch_0.decoder_input_ids[0]
     )
-    assert torch.all(next_batch.decoder_input_ids[1:, 0] == 0)
+    assert next_batch.all_decoder_input_ids[1][0] == 0
+    assert next_batch.all_decoder_input_ids[2][0] == 0
     assert torch.equal(
         next_batch.decoder_input_ids[1:, -2:], next_batch_1.decoder_input_ids
     )
@@ -258,16 +262,16 @@ def test_batch_concatenate(

     assert next_batch.past_key_values is not None
     assert all(
-        [p[0].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
+        [p[0].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values]
     )
     assert all(
-        [p[1].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
+        [p[1].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values]
     )
     assert all(
-        [p[2].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
+        [p[2].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values]
     )
     assert all(
-        [p[3].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
+        [p[3].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values]
     )

     for i, past in enumerate(next_batch.past_key_values):
@@ -306,6 +310,8 @@ def test_batch_concatenate(
     )
     assert generations[2].generated_text.generated_tokens == 5

+    next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]])
+
     generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None

@@ -314,6 +320,8 @@ def test_batch_concatenate(
     assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
     assert generations[0].generated_text.generated_tokens == 7

+    next_batch = next_batch.filter([next_batch.requests[1]])
+
     generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is None

@@ -428,7 +428,7 @@ class CausalLM(Model):
     @tracer.start_as_current_span("generate_token")
     def generate_token(
         self, batch: CausalLMBatch
-    ) -> Tuple[List[Generation], CausalLMBatch]:
+    ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
         # slice the attention mask to the correct shape
         attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]

@@ -545,6 +545,10 @@ class CausalLM(Model):
             batch.token_offsets[i] = token_offset
             batch.max_input_length = max(batch.max_input_length, new_input_length)

+        # We finished all generations in the batch; there is no next batch
+        if stopped:
+            return generations, None
+
         # Slice unused values from prefill
         batch.input_ids = batch.input_ids[:, :1]

@@ -559,4 +563,4 @@ class CausalLM(Model):
         # Update past key values
         batch.past_key_values = past

-        return generations, batch if not stopped else None
+        return generations, batch
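With generate_token now typed as returning Optional[CausalLMBatch], a None batch means every request finished and there is nothing left to schedule. A small usage sketch under that assumption (the loop and names are illustrative, not the actual server loop):

def run_to_completion(model, batch):
    # Generate tokens until generate_token signals completion by returning None.
    all_generations = []
    while batch is not None:
        generations, batch = model.generate_token(batch)
        all_generations.extend(generations)
    return all_generations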
@@ -96,11 +96,13 @@ class GalacticaCausalLMBatch(CausalLMBatch):
         stopping_criterias = []
         offsets = []
         token_offsets = []
+        requests_idx_mapping = {}

         # Parse batch
         max_truncation = 0
         padding_right_offset = 0
-        for r in pb.requests:
+        for i, r in enumerate(pb.requests):
+            requests_idx_mapping[r.id] = i
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
             offsets.append(None)
@@ -115,7 +117,6 @@ class GalacticaCausalLMBatch(CausalLMBatch):
                 padding_right_offset, stopping_criteria.max_new_tokens
             )

-        # Tokenize batch
         tokenized_inputs = tokenizer(
             inputs,
             return_tensors="pt",
@@ -138,23 +139,23 @@ class GalacticaCausalLMBatch(CausalLMBatch):

         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
         position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
-        all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
+        all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)

         return cls(
             batch_id=pb.id,
             requests=pb.requests,
+            requests_idx_mapping=requests_idx_mapping,
             input_ids=input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=None,
-            all_input_ids=all_input_ids,
-            input_lengths=input_lengths,
+            all_input_ids=list(all_input_ids),
+            input_lengths=input_lengths.tolist(),
             offsets=offsets,
             token_offsets=token_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
-            size=pb.size,
-            max_input_length=max_input_length,
+            max_input_length=max_input_length.item(),
             padding_right_offset=padding_right_offset,
         )

@@ -3,7 +3,7 @@ import torch
 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
-from typing import Optional, Tuple, List, Type
+from typing import Optional, Tuple, List, Type, Dict

 from text_generation_server.models import Model
 from text_generation_server.models.types import (
@@ -22,6 +22,7 @@ tracer = trace.get_tracer(__name__)
 class Seq2SeqLMBatch(Batch):
     batch_id: int
     requests: List[generate_pb2.Request]
+    requests_idx_mapping: Dict[int, int]

     # Encoder values
     input_ids: torch.Tensor
@@ -32,6 +33,9 @@ class Seq2SeqLMBatch(Batch):
     decoder_attention_mask: Optional[torch.Tensor]
     encoder_last_hidden_state: Optional[torch.Tensor]

+    # All tokens
+    all_decoder_input_ids: List[torch.Tensor]
+
     # Seq2SeqLM keeps track of both encoder and decoder attention keys and values
     past_key_values: Optional[List[Tuple]]

@@ -46,7 +50,6 @@ class Seq2SeqLMBatch(Batch):
     stopping_criterias: List[StoppingCriteria]

     # Metadata used for padding
-    size: int
     max_input_length: int
     max_decoder_input_length: int
     padding_right_offset: int
@@ -54,9 +57,7 @@ class Seq2SeqLMBatch(Batch):
     def to_pb(self) -> generate_pb2.Batch:
         """Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf"""
         return generate_pb2.Batch(
-            id=self.batch_id,
-            requests=self.requests,
-            size=self.size,
+            id=self.batch_id, requests=self.requests, size=len(self)
         )

     @classmethod
@@ -71,18 +72,17 @@ class Seq2SeqLMBatch(Batch):
         next_token_choosers = []
         stopping_criterias = []

-        decoder_input_ids = []
         decoder_input_lengths = []
         offsets = []
         token_offsets = []
+        requests_idx_mapping = {}

         # Parse batch
         max_truncation = 0
         padding_right_offset = 0
-        for r in pb.requests:
+        for i, r in enumerate(pb.requests):
             inputs.append(r.inputs)
-            # Decoder sequence only contains the bos_token
-            decoder_input_ids.append(tokenizer.bos_token_id)
+            requests_idx_mapping[r.id] = i
             decoder_input_lengths.append(1)
             offsets.append(None)
             token_offsets.append(None)
@@ -109,15 +109,22 @@ class Seq2SeqLMBatch(Batch):
         input_lengths = tokenized_inputs["attention_mask"].sum(1)
         max_input_length = input_lengths.max()

-        # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
-        decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)
+        # Decoder sequence only contains the bos_token
+        decoder_input_ids = (
+            torch.tensor(tokenizer.bos_token_id, device=device)
+            .repeat(len(pb.requests))
+            .view(-1, 1)
+        )
+        all_decoder_input_ids = decoder_input_ids.view(-1).split(1)

         return cls(
             batch_id=pb.id,
             requests=pb.requests,
+            requests_idx_mapping=requests_idx_mapping,
             input_ids=tokenized_inputs["input_ids"],
             attention_mask=tokenized_inputs["attention_mask"],
             decoder_input_ids=decoder_input_ids,
+            all_decoder_input_ids=list(all_decoder_input_ids),
             decoder_attention_mask=None,
             encoder_last_hidden_state=None,
             past_key_values=None,
@@ -127,12 +134,96 @@ class Seq2SeqLMBatch(Batch):
             token_offsets=token_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
-            size=len(pb.requests),
             max_input_length=max_input_length.item(),
             max_decoder_input_length=1,
             padding_right_offset=padding_right_offset,
         )

+    @tracer.start_as_current_span("filter")
+    def filter(
+        self, requests: List[generate_pb2.Request]
+    ) -> Optional["Seq2SeqLMBatch"]:
+        if len(requests) == 0:
+            raise ValueError("Batch must have at least one request")
+        if len(requests) == len(self):
+            return self
+
+        keep_indices = []
+
+        # New values after filtering
+        requests_idx_mapping = {}
+        input_lengths = []
+        decoder_input_lengths = []
+        offsets = []
+        token_offsets = []
+
+        all_decoder_input_ids = []
+
+        next_token_choosers = []
+        stopping_criterias = []
+
+        max_input_length = 0
+        max_decoder_input_length = 0
+
+        for i, r in enumerate(requests):
+            idx = self.requests_idx_mapping[r.id]
+            requests_idx_mapping[r.id] = i
+            keep_indices.append(idx)
+
+            offsets.append(self.offsets[idx])
+            token_offsets.append(self.token_offsets[idx])
+
+            all_decoder_input_ids.append(self.all_decoder_input_ids[idx])
+
+            request_input_length = self.input_lengths[idx]
+            input_lengths.append(request_input_length)
+            max_input_length = max(max_input_length, request_input_length)
+
+            request_decoder_input_length = self.decoder_input_lengths[idx]
+            decoder_input_lengths.append(request_decoder_input_length)
+            max_decoder_input_length = max(
+                max_decoder_input_length, request_decoder_input_length
+            )
+
+            next_token_choosers.append(self.next_token_choosers[idx])
+            stopping_criterias.append(self.stopping_criterias[idx])
+
+        # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
+        decoder_input_ids = self.decoder_input_ids[keep_indices]
+        attention_mask = self.attention_mask[keep_indices]
+        if self.decoder_attention_mask is not None:
+            decoder_attention_mask = self.decoder_attention_mask[keep_indices]
+        else:
+            decoder_attention_mask = None
+
+        encoder_last_hidden_state = self.encoder_last_hidden_state[keep_indices]
+
+        past_key_values = [
+            [t[keep_indices] for t in layer] for layer in self.past_key_values
+        ]
+
+        return Seq2SeqLMBatch(
+            batch_id=self.batch_id,
+            requests=requests,
+            requests_idx_mapping=requests_idx_mapping,
+            input_ids=None,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            all_decoder_input_ids=all_decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            encoder_last_hidden_state=encoder_last_hidden_state,
+            past_key_values=past_key_values,
+            input_lengths=input_lengths,
+            decoder_input_lengths=decoder_input_lengths,
+            offsets=offsets,
+            token_offsets=token_offsets,
+            next_token_choosers=next_token_choosers,
+            stopping_criterias=stopping_criterias,
+            max_input_length=max_input_length,
+            max_decoder_input_length=max_decoder_input_length,
+            padding_right_offset=self.padding_right_offset,
+        )
+
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
@@ -144,7 +235,7 @@ class Seq2SeqLMBatch(Batch):
         max_decoder_input_length = 0
         padding_right_offset = 0
         for batch in batches:
-            total_batch_size += batch.size
+            total_batch_size += len(batch)
             max_input_length = max(max_input_length, batch.max_input_length)
             max_decoder_input_length = max(
                 max_decoder_input_length, batch.max_decoder_input_length
@@ -153,6 +244,8 @@ class Seq2SeqLMBatch(Batch):

         # Batch attributes
         requests = []
+        requests_idx_mapping = {}
+        all_decoder_input_ids = []
         input_lengths = []
         decoder_input_lengths = []
         offsets = []
@@ -174,6 +267,7 @@ class Seq2SeqLMBatch(Batch):
         for i, batch in enumerate(batches):
             # Extend all list attributes
             requests.extend(batch.requests)
+            all_decoder_input_ids.extend(batch.all_decoder_input_ids)
             input_lengths.extend(batch.input_lengths)
             decoder_input_lengths.extend(batch.decoder_input_lengths)
             offsets.extend(batch.offsets)
@@ -181,8 +275,15 @@ class Seq2SeqLMBatch(Batch):
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)

+            if i == 0:
+                requests_idx_mapping = batch.requests_idx_mapping
+            else:
+                # We need to offset the mapping for each batch by the cumulative batch size
+                for k, v in batch.requests_idx_mapping.items():
+                    requests_idx_mapping[k] = v + start_index
+
             # Slicing end index for this batch
-            end_index = start_index + batch.size
+            end_index = start_index + len(batch)

             # We only concatenate batches that did at least one step
             if batch.encoder_last_hidden_state is None:
@@ -201,12 +302,10 @@ class Seq2SeqLMBatch(Batch):
             # Create padded tensor
             if decoder_input_ids is None:
                 decoder_input_ids = batch.decoder_input_ids.new_zeros(
-                    (total_batch_size, max_decoder_input_length),
+                    (total_batch_size, 1),
                 )
             # Copy to correct indices
-            decoder_input_ids[
-                start_index:end_index, -batch.max_decoder_input_length :
-            ] = batch.decoder_input_ids[:, -batch.max_decoder_input_length :]
+            decoder_input_ids[start_index:end_index] = batch.decoder_input_ids

             # Create padded tensor
             if decoder_attention_mask is None:
@@ -302,14 +401,16 @@ class Seq2SeqLMBatch(Batch):
                     start_index:end_index, :, -batch.max_input_length :, :
                 ] = t[:, :, -batch.max_input_length :, :]

-            start_index += batch.size
+            start_index += len(batch)

         return cls(
             batch_id=batches[0].batch_id,
             requests=requests,
+            requests_idx_mapping=requests_idx_mapping,
             input_ids=None,
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
+            all_decoder_input_ids=all_decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
             encoder_last_hidden_state=encoder_last_hidden_state,
             past_key_values=past_key_values,
@@ -319,7 +420,6 @@ class Seq2SeqLMBatch(Batch):
             token_offsets=token_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
-            size=total_batch_size,
             max_input_length=max_input_length,
             max_decoder_input_length=max_decoder_input_length,
             padding_right_offset=padding_right_offset,
@@ -413,46 +513,25 @@ class Seq2SeqLM(Model):
         else:
             decoder_attention_mask = None

-        # check if first forward or not
-        if batch.past_key_values is not None:
-            # Only take the last token
-            decoder_input_ids = batch.decoder_input_ids[:, -1].unsqueeze(-1)
-        else:
-            decoder_input_ids = batch.decoder_input_ids
-
         # Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
         # internally...
         if batch.encoder_last_hidden_state is not None:
             encoder_last_hidden_state = [batch.encoder_last_hidden_state]
         else:
-            encoder_last_hidden_state = batch.encoder_last_hidden_state
+            encoder_last_hidden_state = None

         logits, encoder_last_hidden_state, past = self.forward(
             batch.input_ids,
             batch.attention_mask,
-            decoder_input_ids,
+            batch.decoder_input_ids,
             decoder_attention_mask,
             encoder_last_hidden_state,
             batch.past_key_values,
         )

-        # List of indices to cache
-        next_batch_keep_indices = []
-
-        # New values for next forward
-        next_batch_input_lengths = []
-        next_batch_offsets = []
-        next_batch_token_offsets = []
-        next_batch_decoder_input_ids = []
-        next_batch_decoder_input_lengths = []
-
-        # Metadata
-        next_batch_size = 0
-        next_batch_max_input_length = 0
-        next_batch_max_decoder_input_length = 0
-
         # Finished requests
         generations: List[Generation] = []
+        stopped = True

         # Zipped iterator
         iterator = zip(
@@ -464,7 +543,7 @@ class Seq2SeqLM(Model):
             logits,
             batch.next_token_choosers,
             batch.stopping_criterias,
-            batch.decoder_input_ids,
+            batch.all_decoder_input_ids,
         )

         # For each member of the batch
@@ -477,22 +556,24 @@ class Seq2SeqLM(Model):
             logits,
             next_token_chooser,
             stopping_criteria,
-            decoder_input_ids,
+            all_decoder_input_ids,
         ) in enumerate(iterator):
             # Select next token
             next_token_id, logprobs = next_token_chooser(
-                decoder_input_ids.view(1, -1), logits
+                all_decoder_input_ids.view(1, -1), logits
             )

             # Append next token to decoder tokens
-            decoder_input_ids = torch.cat([decoder_input_ids, next_token_id.squeeze(1)])
+            all_decoder_input_ids = torch.cat(
+                [all_decoder_input_ids, next_token_id.squeeze(1)]
+            )
             new_decoder_input_length = decoder_input_length + 1

             # Generated token
             next_token_logprob = logprobs[-1, next_token_id]
             next_token_id_squeezed = next_token_id.squeeze()
             next_token_text, offset, token_offset = self.decode_token(
-                decoder_input_ids, offset, token_offset
+                all_decoder_input_ids, offset, token_offset
             )

             # Evaluate stopping criteria
@@ -501,7 +582,7 @@ class Seq2SeqLM(Model):
             if stop:
                 # Slice with decoder_input_length to remove padding
                 # Decode all tokens
-                output_text = self.decode(decoder_input_ids[-decoder_input_length:])
+                output_text = self.decode(all_decoder_input_ids[-decoder_input_length:])

                 # Get seed
                 if isinstance(next_token_chooser.choice, Sampling):
@@ -515,19 +596,7 @@ class Seq2SeqLM(Model):
             else:
                 # Keep request in the batch
                 generated_text = None
-                next_batch_keep_indices.append(i)
-                next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0))
-                next_batch_size += 1
-                next_batch_input_lengths.append(input_length)
-                next_batch_decoder_input_lengths.append(new_decoder_input_length)
-                next_batch_offsets.append(offset)
-                next_batch_token_offsets.append(token_offset)
-                next_batch_max_input_length = max(
-                    next_batch_max_input_length, input_length
-                )
-                next_batch_max_decoder_input_length = max(
-                    next_batch_max_decoder_input_length, new_decoder_input_length
-                )
+                stopped = False

             # Prefill
             if stopping_criteria.current_tokens == 1:
@@ -551,69 +620,29 @@

             generations.append(generation)

+            # Update values
+            batch.decoder_input_ids[i] = next_token_id
+            batch.all_decoder_input_ids[i] = all_decoder_input_ids
+            batch.input_lengths[i] = input_length
+            batch.decoder_input_lengths[i] = new_decoder_input_length
+            batch.offsets[i] = offset
+            batch.token_offsets[i] = token_offset
+            batch.max_input_length = max(batch.max_input_length, input_length)
+            batch.max_decoder_input_length = max(
+                batch.max_decoder_input_length, new_decoder_input_length
+            )
+
         # We finished all generations in the batch; there is no next batch
-        if not next_batch_keep_indices:
+        if stopped:
             return generations, None

-        next_batch_decoder_input_ids = torch.cat(next_batch_decoder_input_ids)
-        # If we finished at least one generation, we need to evict the indices of the generations that finished
-        # from the values of the next batch
-        if len(next_batch_keep_indices) != len(batch):
-            # Apply indices to decoder_attention mask, past key values and other items that need to be cached
-            next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
-            if batch.decoder_attention_mask is not None:
-                next_batch_decoder_attention_mask = batch.decoder_attention_mask[
-                    next_batch_keep_indices
-                ]
-            else:
-                next_batch_decoder_attention_mask = None
-
-            next_batch_encoder_last_hidden_state = encoder_last_hidden_state[
-                next_batch_keep_indices
-            ]
-
-            next_batch_past_key_values = [
-                [t[next_batch_keep_indices] for t in layer] for layer in past
-            ]
-            next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
-            next_batch_next_token_choosers = [
-                batch.next_token_choosers[i] for i in next_batch_keep_indices
-            ]
-            next_batch_stopping_criterias = [
-                batch.stopping_criterias[i] for i in next_batch_keep_indices
-            ]
-        else:
-            next_batch_attention_mask = batch.attention_mask
-            next_batch_decoder_attention_mask = batch.decoder_attention_mask
-            next_batch_encoder_last_hidden_state = encoder_last_hidden_state
-            next_batch_past_key_values = past
-
-            next_batch_requests = batch.requests
-            next_batch_next_token_choosers = batch.next_token_choosers
-            next_batch_stopping_criterias = batch.stopping_criterias
-
+        # We don't need input_ids after the prefill forward
+        batch.input_ids = None
+        batch.encoder_last_hidden_state = encoder_last_hidden_state
+        batch.past_key_values = past
         # Update decoder_attention_mask as we added a new token to input_ids
-        if next_batch_decoder_attention_mask is not None:
-            next_batch_decoder_attention_mask[:, -batch.padding_right_offset] = 1
+        if batch.decoder_attention_mask is not None:
+            batch.decoder_attention_mask[:, -batch.padding_right_offset] = 1
+        batch.padding_right_offset -= 1

-        next_batch = Seq2SeqLMBatch(
-            batch_id=batch.batch_id,
-            requests=next_batch_requests,
-            input_ids=None,
-            attention_mask=next_batch_attention_mask,
-            decoder_input_ids=next_batch_decoder_input_ids,
-            decoder_attention_mask=next_batch_decoder_attention_mask,
-            encoder_last_hidden_state=next_batch_encoder_last_hidden_state,
-            past_key_values=next_batch_past_key_values,
-            input_lengths=next_batch_input_lengths,
-            decoder_input_lengths=next_batch_decoder_input_lengths,
-            offsets=next_batch_offsets,
-            token_offsets=next_batch_token_offsets,
-            next_token_choosers=next_batch_next_token_choosers,
-            stopping_criterias=next_batch_stopping_criterias,
-            size=next_batch_size,
-            max_input_length=next_batch_max_input_length,
-            max_decoder_input_length=next_batch_max_decoder_input_length,
-            padding_right_offset=batch.padding_right_offset - 1,
-        )
-        return generations, next_batch
+        return generations, batch
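Seq2SeqLM.generate_token no longer rebuilds a pruned next batch itself; it mutates the batch in place and returns it, and callers drop finished requests explicitly through the new filter method, as the updated tests do. A minimal sketch of that calling pattern (illustrative only; the helper name is made up):

def prune_finished(batch, generations):
    # Keep only the requests whose generation has not produced a final
    # generated_text yet; filter rejects an empty request list.
    unfinished = [
        request
        for request, generation in zip(batch.requests, generations)
        if generation.generated_text is None
    ]
    if not unfinished:
        return None
    return batch.filter(unfinished)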