Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-10 11:54:52 +00:00.

Commit 118f33d9dc (parent d9578153cb): fix queue
@@ -3,6 +3,7 @@ use crate::infer::InferStreamResponse;
 use crate::validation::ValidGenerateRequest;
 use nohash_hasher::{BuildNoHashHasher, IntMap};
 use std::cmp::min;
+use std::collections::VecDeque;
 use text_generation_client::{Batch, Request};
 use tokio::sync::oneshot;
 use tokio::time::Instant;
@@ -102,7 +103,7 @@ async fn queue_task(receiver: flume::Receiver<QueueCommand>) {
 #[derive(Debug)]
 struct State {
     /// Queue entries organized in a Vec
-    entries: Vec<(u64, Entry)>,
+    entries: VecDeque<(u64, Entry)>,

     /// Id of the next entry
     next_id: u64,
@@ -114,7 +115,7 @@ struct State {
 impl State {
     fn new() -> Self {
         Self {
-            entries: Vec::with_capacity(128),
+            entries: VecDeque::with_capacity(128),
             next_id: 0,
             next_batch_id: 0,
         }
@@ -127,7 +128,7 @@ impl State {
         entry.temp_span = Some(queue_span);

         // Push entry in the queue
-        self.entries.push((self.next_id, entry));
+        self.entries.push_back((self.next_id, entry));
         self.next_id += 1;
         metrics::increment_gauge!("tgi_queue_size", 1.0);
     }
@@ -155,8 +156,8 @@ impl State {
         let mut batch_entries =
             IntMap::with_capacity_and_hasher(max_batch_size, BuildNoHashHasher::default());

-        // Drain next_batch_size entries
-        for (id, mut entry) in self.entries.drain(..max_batch_size) {
+        // Iterate on buffer
+        while let Some((id, mut entry)) = self.entries.pop_front() {
             // Filter entries where the response receiver was dropped (== entries where the request
             // was dropped by the client)
             if entry.response_tx.is_disconnected() {
@@ -182,6 +183,16 @@ impl State {
             entry.batch_time = Some(Instant::now());
             // Insert in batch_entries IntMap
             batch_entries.insert(id, entry);
+
+            if batch_requests.len() == max_batch_size {
+                // We have enough requests in the batch
+                break;
+            }
+        }
+
+        // Maybe all entries were dropped because their channel were closed
+        if batch_requests.is_empty() {
+            return None;
         }

         // Final batch size once we dropped entries
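The two hunks above are the heart of the fix: `Vec::drain(..max_batch_size)` panics whenever the queue holds fewer entries than `max_batch_size`, while `VecDeque::pop_front` simply returns `None` when the queue runs out, and the loop can also stop early once the batch is full. A minimal standalone sketch of the difference, with a hypothetical `Entry` stand-in rather than the router's real type:

use std::collections::VecDeque;

// Stand-in for the router's queue entry type (hypothetical, for illustration only).
struct Entry(u64);

fn main() {
    let max_batch_size: usize = 4;

    // A Vec-based queue: drain(..max_batch_size) panics if the queue is shorter
    // than the requested range.
    let vec_queue: Vec<(u64, Entry)> = vec![(0, Entry(0)), (1, Entry(1))];
    assert!(vec_queue.len() < max_batch_size);
    // vec_queue.drain(..max_batch_size); // would panic: range end out of bounds

    // A VecDeque-based queue: pop_front returns None once the queue is empty,
    // so the loop ends cleanly, or breaks early once the batch is full.
    let mut deque: VecDeque<(u64, Entry)> = (0..2).map(|i| (i, Entry(i))).collect();
    let mut batch = Vec::new();
    while let Some((id, entry)) = deque.pop_front() {
        batch.push((id, entry));
        if batch.len() == max_batch_size {
            break;
        }
    }
    assert_eq!(batch.len(), 2); // smaller than max_batch_size, but no panic
}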
@@ -218,15 +229,16 @@ enum QueueCommand {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use std::sync::Arc;
     use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
-    use tokio::sync::Semaphore;
     use tracing::info_span;

-    fn default_entry() -> Entry {
-        let (response_tx, _) = flume::unbounded();
+    fn default_entry() -> (
+        Entry,
+        flume::Receiver<Result<InferStreamResponse, InferError>>,
+    ) {
+        let (response_tx, receiver_tx) = flume::unbounded();

-        Entry {
+        let entry = Entry {
             request: ValidGenerateRequest {
                 inputs: "".to_string(),
                 truncate: 0,
@@ -251,13 +263,14 @@ mod tests {
             temp_span: None,
             queue_time: Instant::now(),
             batch_time: None,
-        }
+        };
+        (entry, receiver_tx)
     }

     #[test]
     fn test_append() {
         let mut state = State::new();
-        let entry = default_entry();
+        let (entry, _guard) = default_entry();

         assert_eq!(state.next_id, 0);
         assert_eq!(state.entries.len(), 0);
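The helper change above matters because `next_batch` drops any entry whose flume response channel reports `is_disconnected()`. The old `default_entry()` bound the receiver to `_`, which drops it immediately, so every test entry looked like a client that had already gone away; returning the receiver and binding it to a named `_guard` keeps the channel open for the duration of the test. A small sketch of that behaviour, assuming only the flume crate as a dependency:

// Requires the flume crate (the same channel library the router uses).
fn main() {
    // Binding the receiver to `_` drops it immediately: the sender sees a
    // disconnected channel, exactly like a client that went away.
    let (tx, _) = flume::unbounded::<u32>();
    assert!(tx.is_disconnected());

    // Binding it to a named `_guard` keeps the channel open until the guard
    // goes out of scope, which is what the updated tests rely on.
    let (tx, _guard) = flume::unbounded::<u32>();
    assert!(!tx.is_disconnected());
}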
@@ -266,7 +279,7 @@ mod tests {

         assert_eq!(state.next_id, 1);
         assert_eq!(state.entries.len(), 1);
-        let (id, _) = state.entries.remove(0);
+        let (id, _) = state.entries.remove(0).unwrap();
         assert_eq!(id, 0);
     }

@@ -281,8 +294,10 @@ mod tests {
     #[test]
     fn test_next_batch_min_size() {
         let mut state = State::new();
-        state.append(default_entry());
-        state.append(default_entry());
+        let (entry1, _guard1) = default_entry();
+        let (entry2, _guard2) = default_entry();
+        state.append(entry1);
+        state.append(entry2);

         let (entries, batch, _) = state.next_batch(None, 2).unwrap();
         assert_eq!(entries.len(), 2);
@@ -297,21 +312,24 @@ mod tests {
         assert_eq!(state.entries.len(), 0);
         assert_eq!(state.next_batch_id, 1);

-        state.append(default_entry());
+        let (entry3, _guard3) = default_entry();
+        state.append(entry3);

         assert!(state.next_batch(Some(2), 2).is_none());

         assert_eq!(state.next_id, 3);
         assert_eq!(state.entries.len(), 1);
-        let (id, _) = state.entries.remove(0);
+        let (id, _) = state.entries.remove(0).unwrap();
         assert_eq!(id, 2);
     }

     #[test]
     fn test_next_batch_max_size() {
         let mut state = State::new();
-        state.append(default_entry());
-        state.append(default_entry());
+        let (entry1, _guard1) = default_entry();
+        let (entry2, _guard2) = default_entry();
+        state.append(entry1);
+        state.append(entry2);

         let (entries, batch, _) = state.next_batch(None, 1).unwrap();
         assert_eq!(entries.len(), 1);
@@ -323,7 +341,8 @@ mod tests {
         assert_eq!(state.entries.len(), 1);
         assert_eq!(state.next_batch_id, 1);

-        state.append(default_entry());
+        let (entry3, _guard3) = default_entry();
+        state.append(entry3);

         let (entries, batch, _) = state.next_batch(None, 3).unwrap();
         assert_eq!(entries.len(), 2);
@@ -340,7 +359,8 @@ mod tests {
     #[tokio::test]
     async fn test_queue_append() {
         let queue = Queue::new();
-        queue.append(default_entry());
+        let (entry, _guard) = default_entry();
+        queue.append(entry);
     }

     #[tokio::test]
@@ -354,8 +374,10 @@ mod tests {
     #[tokio::test]
     async fn test_queue_next_batch_min_size() {
         let queue = Queue::new();
-        queue.append(default_entry());
-        queue.append(default_entry());
+        let (entry1, _guard1) = default_entry();
+        let (entry2, _guard2) = default_entry();
+        queue.append(entry1);
+        queue.append(entry2);

         let (entries, batch, _) = queue.next_batch(None, 2).await.unwrap();
         assert_eq!(entries.len(), 2);
@@ -366,7 +388,8 @@ mod tests {
         assert_eq!(batch.id, 0);
         assert_eq!(batch.size, 2);

-        queue.append(default_entry());
+        let (entry3, _guard3) = default_entry();
+        queue.append(entry3);

         assert!(queue.next_batch(Some(2), 2).await.is_none());
     }
@@ -374,8 +397,10 @@ mod tests {
     #[tokio::test]
     async fn test_queue_next_batch_max_size() {
         let queue = Queue::new();
-        queue.append(default_entry());
-        queue.append(default_entry());
+        let (entry1, _guard1) = default_entry();
+        let (entry2, _guard2) = default_entry();
+        queue.append(entry1);
+        queue.append(entry2);

         let (entries, batch, _) = queue.next_batch(None, 1).await.unwrap();
         assert_eq!(entries.len(), 1);
@@ -383,7 +408,8 @@ mod tests {
         assert_eq!(batch.id, 0);
         assert_eq!(batch.size, 1);

-        queue.append(default_entry());
+        let (entry3, _guard3) = default_entry();
+        queue.append(entry3);

         let (entries, batch, _) = queue.next_batch(None, 3).await.unwrap();
         assert_eq!(entries.len(), 2);
@@ -392,4 +418,13 @@ mod tests {
         assert_eq!(batch.id, 1);
         assert_eq!(batch.size, 2);
     }
+
+    #[tokio::test]
+    async fn test_queue_next_batch_dropped_receiver() {
+        let queue = Queue::new();
+        let (entry, _) = default_entry();
+        queue.append(entry);
+
+        assert!(queue.next_batch(None, 1).await.is_none());
+    }
 }
@@ -1,6 +1,9 @@
 include Makefile-transformers
 include Makefile-flash-att

+unit-tests:
+	python -m pytest tests
+
 gen-server:
 	# Compile protos
 	pip install grpcio-tools==1.51.1 mypy-protobuf==3.4.0 'types-protobuf>=3.20.4' --no-cache-dir
@@ -45,8 +45,9 @@ def default_bloom_batch(default_pb_batch, bloom_560m_tokenizer):
 @pytest.fixture
 def default_multi_requests_bloom_batch(default_pb_request, bloom_560m_tokenizer):
     req_0 = copy(default_pb_request)
+    req_0.id = 1
     req_1 = default_pb_request
-    req_1.id = 1
+    req_1.id = 2
     req_1.stopping_parameters.max_new_tokens = 5

     batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
@@ -70,12 +71,17 @@ def test_batch_from_pb(default_pb_batch, default_bloom_batch):

     assert batch.past_key_values is None

-    assert torch.equal(batch.input_ids, batch.all_input_ids[:, :, 0])
+    assert all(
+        [
+            torch.equal(input_ids, all_input_ids[:, 0])
+            for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids)
+        ]
+    )

     assert batch.input_lengths == [1]

-    assert batch.size == default_pb_batch.size
-    assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
+    assert len(batch) == default_pb_batch.size
+    assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch)

     assert batch.max_input_length == batch.input_lengths[0]

@@ -97,7 +103,7 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
     assert isinstance(next_batch, CausalLMBatch)
     assert not next_batch.keys_head_dim_last

-    assert len(next_batch.all_input_ids) == next_batch.size
+    assert len(next_batch.all_input_ids) == len(next_batch)
     assert len(next_batch.all_input_ids[0]) == sequence_length + 1
     assert len(next_batch.attention_mask[0]) == 11
     assert torch.all(next_batch.all_input_ids[0][-2:] == 10264)
@@ -106,7 +112,7 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
     assert torch.all(next_batch.attention_mask[0][:2] == 1)
     assert torch.all(next_batch.attention_mask[0][2:] == 0)

-    assert next_batch.input_ids.shape == (next_batch.size, 1)
+    assert next_batch.input_ids.shape == (len(next_batch), 1)
     assert next_batch.input_ids[0, 0] == 10264

     assert next_batch.input_lengths == [2]
@@ -170,6 +176,8 @@ def test_causal_lm_generate_token_completion_multi(
         == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
     )

+    next_batch = next_batch.filter([next_batch.requests[0]])
+
     for _ in range(
         default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
         - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
@@ -269,6 +277,8 @@ def test_batch_concatenate(
         == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
     )

+    next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]])
+
     for _ in range(
         default_bloom_batch.stopping_criterias[0].max_new_tokens
         - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
@@ -290,6 +300,8 @@ def test_batch_concatenate(
         == default_bloom_batch.stopping_criterias[0].max_new_tokens
     )

+    next_batch = next_batch.filter([next_batch.requests[1]])
+
     for _ in range(
         default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
         - default_bloom_batch.stopping_criterias[0].max_new_tokens
@@ -68,7 +68,12 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):

     assert batch.past_key_values is None

-    assert all([torch.equal(input_ids, all_input_ids[:, 0]) for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids)])
+    assert all(
+        [
+            torch.equal(input_ids, all_input_ids[:, 0])
+            for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids)
+        ]
+    )

     assert batch.input_lengths == [1]

@@ -49,8 +49,9 @@ def default_seq2seq_lm_batch(default_pb_batch, mt0_small_tokenizer):
 @pytest.fixture
 def default_multi_requests_seq2seq_lm_batch(default_pb_request, mt0_small_tokenizer):
     req_0 = copy(default_pb_request)
+    req_0.id = 1
     req_1 = default_pb_request
-    req_1.id = 1
+    req_1.id = 2
     req_1.stopping_parameters.max_new_tokens = 5

     batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
@@ -72,7 +73,7 @@ def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch):
     assert torch.all(batch.attention_mask[0][-2:] == 1)
     assert torch.all(batch.attention_mask[0][:-2] == 0)

-    assert batch.decoder_input_ids.shape == (default_pb_batch.size, 1)
+    assert len(batch.decoder_input_ids) == default_pb_batch.size
     assert batch.decoder_attention_mask is None
     assert batch.encoder_last_hidden_state is None

@@ -81,8 +82,8 @@ def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch):
     assert batch.input_lengths == [2]
     assert batch.decoder_input_lengths == [1]

-    assert batch.size == default_pb_batch.size
-    assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
+    assert len(batch) == default_pb_batch.size
+    assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch)

     assert batch.max_input_length == batch.input_lengths[0]
     assert batch.max_decoder_input_length == batch.decoder_input_lengths[0]
@@ -117,9 +118,9 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
     )
     assert next_batch.stopping_criterias == default_seq2seq_lm_batch.stopping_criterias

-    assert next_batch.decoder_input_ids.shape == (next_batch.size, 2)
-    assert next_batch.decoder_input_ids[0, 0] == 0
-    assert next_batch.decoder_input_ids[0, 1] == 259
+    assert len(next_batch.decoder_input_ids) == len(next_batch)
+    assert next_batch.all_decoder_input_ids[0][0] == 0
+    assert next_batch.all_decoder_input_ids[0][1] == 259
     assert next_batch.decoder_attention_mask is None
     assert next_batch.encoder_last_hidden_state.shape == (1, sequence_length, 512)

@@ -128,20 +129,20 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)

     assert next_batch.past_key_values is not None
     assert all(
-        [p[0].shape == (next_batch.size, 6, 1, 64) for p in next_batch.past_key_values]
+        [p[0].shape == (len(next_batch), 6, 1, 64) for p in next_batch.past_key_values]
     )
     assert all(
-        [p[1].shape == (next_batch.size, 6, 1, 64) for p in next_batch.past_key_values]
+        [p[1].shape == (len(next_batch), 6, 1, 64) for p in next_batch.past_key_values]
     )
     assert all(
         [
-            p[2].shape == (next_batch.size, 6, sequence_length, 64)
+            p[2].shape == (len(next_batch), 6, sequence_length, 64)
             for p in next_batch.past_key_values
         ]
     )
     assert all(
         [
-            p[3].shape == (next_batch.size, 6, sequence_length, 64)
+            p[3].shape == (len(next_batch), 6, sequence_length, 64)
             for p in next_batch.past_key_values
         ]
     )
@@ -189,6 +190,8 @@ def test_seq2seq_lm_generate_token_completion_multi(
     )
     assert generations[1].generated_text.generated_tokens == 5

+    next_batch = next_batch.filter([next_batch.requests[0]])
+
     generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert len(generations) == len(next_batch)

@@ -223,7 +226,8 @@ def test_batch_concatenate(
     assert torch.equal(
         next_batch.decoder_input_ids[0], next_batch_0.decoder_input_ids[0]
     )
-    assert torch.all(next_batch.decoder_input_ids[1:, 0] == 0)
+    assert next_batch.all_decoder_input_ids[1][0] == 0
+    assert next_batch.all_decoder_input_ids[2][0] == 0
     assert torch.equal(
         next_batch.decoder_input_ids[1:, -2:], next_batch_1.decoder_input_ids
     )
@@ -258,16 +262,16 @@ def test_batch_concatenate(

     assert next_batch.past_key_values is not None
     assert all(
-        [p[0].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
+        [p[0].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values]
     )
     assert all(
-        [p[1].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
+        [p[1].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values]
     )
     assert all(
-        [p[2].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
+        [p[2].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values]
     )
     assert all(
-        [p[3].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
+        [p[3].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values]
     )

     for i, past in enumerate(next_batch.past_key_values):
@@ -306,6 +310,8 @@ def test_batch_concatenate(
     )
     assert generations[2].generated_text.generated_tokens == 5

+    next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]])
+
     generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None

@@ -314,6 +320,8 @@ def test_batch_concatenate(
     assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
     assert generations[0].generated_text.generated_tokens == 7

+    next_batch = next_batch.filter([next_batch.requests[1]])
+
     generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is None

@@ -428,7 +428,7 @@ class CausalLM(Model):
     @tracer.start_as_current_span("generate_token")
     def generate_token(
         self, batch: CausalLMBatch
-    ) -> Tuple[List[Generation], CausalLMBatch]:
+    ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
         # slice the attention mask to the correct shape
         attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]

@@ -545,6 +545,10 @@ class CausalLM(Model):
             batch.token_offsets[i] = token_offset
             batch.max_input_length = max(batch.max_input_length, new_input_length)

+        # We finished all generations in the batch; there is no next batch
+        if stopped:
+            return generations, None
+
         # Slice unused values from prefill
         batch.input_ids = batch.input_ids[:, :1]

@@ -559,4 +563,4 @@ class CausalLM(Model):
         # Update past key values
         batch.past_key_values = past

-        return generations, batch if not stopped else None
+        return generations, batch
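With the hunks above, `generate_token` now advertises that it may return no next batch (`Optional[CausalLMBatch]`) and bails out with `None` as soon as every request has finished, instead of finishing the bookkeeping for a batch it is about to discard. The same control flow, rendered as a small self-contained Rust sketch for illustration only (this `Batch` is a made-up stand-in, not the server's class):

// Hypothetical Batch: `remaining` is how many tokens each request still has to generate.
struct Batch {
    remaining: Vec<u32>,
}

// One decoding step: produce a token for every live request, drop finished ones,
// and return None instead of an empty batch once everything has stopped.
fn generate_token(mut batch: Batch) -> (usize, Option<Batch>) {
    let generated = batch.remaining.len();
    for r in batch.remaining.iter_mut() {
        *r -= 1;
    }
    batch.remaining.retain(|&r| r > 0);
    let next = if batch.remaining.is_empty() {
        None
    } else {
        Some(batch)
    };
    (generated, next)
}

fn main() {
    let mut next = Some(Batch {
        remaining: vec![2, 3],
    });
    let mut total = 0;
    while let Some(batch) = next {
        let (generated, new_batch) = generate_token(batch);
        total += generated;
        next = new_batch;
    }
    assert_eq!(total, 5); // 2 + 3 tokens generated in total
}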
@@ -96,11 +96,13 @@ class GalacticaCausalLMBatch(CausalLMBatch):
         stopping_criterias = []
         offsets = []
         token_offsets = []
+        requests_idx_mapping = {}

         # Parse batch
         max_truncation = 0
         padding_right_offset = 0
-        for r in pb.requests:
+        for i, r in enumerate(pb.requests):
+            requests_idx_mapping[r.id] = i
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
             offsets.append(None)
@@ -115,7 +117,6 @@ class GalacticaCausalLMBatch(CausalLMBatch):
                 padding_right_offset, stopping_criteria.max_new_tokens
             )

-        # Tokenize batch
         tokenized_inputs = tokenizer(
             inputs,
             return_tensors="pt",
@@ -138,23 +139,23 @@ class GalacticaCausalLMBatch(CausalLMBatch):

         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
         position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
-        all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
+        all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)

         return cls(
             batch_id=pb.id,
             requests=pb.requests,
+            requests_idx_mapping=requests_idx_mapping,
             input_ids=input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=None,
-            all_input_ids=all_input_ids,
-            input_lengths=input_lengths,
+            all_input_ids=list(all_input_ids),
+            input_lengths=input_lengths.tolist(),
             offsets=offsets,
             token_offsets=token_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
-            size=pb.size,
-            max_input_length=max_input_length,
+            max_input_length=max_input_length.item(),
             padding_right_offset=padding_right_offset,
         )

@@ -3,7 +3,7 @@ import torch
 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
-from typing import Optional, Tuple, List, Type
+from typing import Optional, Tuple, List, Type, Dict

 from text_generation_server.models import Model
 from text_generation_server.models.types import (
@@ -22,6 +22,7 @@ tracer = trace.get_tracer(__name__)
 class Seq2SeqLMBatch(Batch):
     batch_id: int
     requests: List[generate_pb2.Request]
+    requests_idx_mapping: Dict[int, int]

     # Encoder values
     input_ids: torch.Tensor
@@ -32,6 +33,9 @@ class Seq2SeqLMBatch(Batch):
     decoder_attention_mask: Optional[torch.Tensor]
     encoder_last_hidden_state: Optional[torch.Tensor]

+    # All tokens
+    all_decoder_input_ids: List[torch.Tensor]
+
     # Seq2SeqLM keeps track of both encoder and decoder attention keys and values
     past_key_values: Optional[List[Tuple]]

@@ -46,7 +50,6 @@ class Seq2SeqLMBatch(Batch):
     stopping_criterias: List[StoppingCriteria]

     # Metadata used for padding
-    size: int
     max_input_length: int
     max_decoder_input_length: int
     padding_right_offset: int
@@ -54,9 +57,7 @@ class Seq2SeqLMBatch(Batch):
     def to_pb(self) -> generate_pb2.Batch:
         """Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf"""
         return generate_pb2.Batch(
-            id=self.batch_id,
-            requests=self.requests,
-            size=self.size,
+            id=self.batch_id, requests=self.requests, size=len(self)
         )

     @classmethod
@@ -71,18 +72,17 @@ class Seq2SeqLMBatch(Batch):
         next_token_choosers = []
         stopping_criterias = []

-        decoder_input_ids = []
         decoder_input_lengths = []
         offsets = []
         token_offsets = []
+        requests_idx_mapping = {}

         # Parse batch
         max_truncation = 0
         padding_right_offset = 0
-        for r in pb.requests:
+        for i, r in enumerate(pb.requests):
             inputs.append(r.inputs)
-            # Decoder sequence only contains the bos_token
-            decoder_input_ids.append(tokenizer.bos_token_id)
+            requests_idx_mapping[r.id] = i
             decoder_input_lengths.append(1)
             offsets.append(None)
             token_offsets.append(None)
@@ -109,15 +109,22 @@ class Seq2SeqLMBatch(Batch):
         input_lengths = tokenized_inputs["attention_mask"].sum(1)
         max_input_length = input_lengths.max()

-        # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
-        decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)
+        # Decoder sequence only contains the bos_token
+        decoder_input_ids = (
+            torch.tensor(tokenizer.bos_token_id, device=device)
+            .repeat(len(pb.requests))
+            .view(-1, 1)
+        )
+        all_decoder_input_ids = decoder_input_ids.view(-1).split(1)

         return cls(
             batch_id=pb.id,
             requests=pb.requests,
+            requests_idx_mapping=requests_idx_mapping,
             input_ids=tokenized_inputs["input_ids"],
             attention_mask=tokenized_inputs["attention_mask"],
             decoder_input_ids=decoder_input_ids,
+            all_decoder_input_ids=list(all_decoder_input_ids),
             decoder_attention_mask=None,
             encoder_last_hidden_state=None,
             past_key_values=None,
@@ -127,12 +134,96 @@ class Seq2SeqLMBatch(Batch):
             token_offsets=token_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
-            size=len(pb.requests),
             max_input_length=max_input_length.item(),
             max_decoder_input_length=1,
             padding_right_offset=padding_right_offset,
         )

+    @tracer.start_as_current_span("filter")
+    def filter(
+        self, requests: List[generate_pb2.Request]
+    ) -> Optional["Seq2SeqLMBatch"]:
+        if len(requests) == 0:
+            raise ValueError("Batch must have at least one request")
+        if len(requests) == len(self):
+            return self
+
+        keep_indices = []
+
+        # New values after filtering
+        requests_idx_mapping = {}
+        input_lengths = []
+        decoder_input_lengths = []
+        offsets = []
+        token_offsets = []
+
+        all_decoder_input_ids = []
+
+        next_token_choosers = []
+        stopping_criterias = []
+
+        max_input_length = 0
+        max_decoder_input_length = 0
+
+        for i, r in enumerate(requests):
+            idx = self.requests_idx_mapping[r.id]
+            requests_idx_mapping[r.id] = i
+            keep_indices.append(idx)
+
+            offsets.append(self.offsets[idx])
+            token_offsets.append(self.token_offsets[idx])
+
+            all_decoder_input_ids.append(self.all_decoder_input_ids[idx])
+
+            request_input_length = self.input_lengths[idx]
+            input_lengths.append(request_input_length)
+            max_input_length = max(max_input_length, request_input_length)
+
+            request_decoder_input_length = self.decoder_input_lengths[idx]
+            decoder_input_lengths.append(request_decoder_input_length)
+            max_decoder_input_length = max(
+                max_decoder_input_length, request_decoder_input_length
+            )
+
+            next_token_choosers.append(self.next_token_choosers[idx])
+            stopping_criterias.append(self.stopping_criterias[idx])
+
+        # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
+        decoder_input_ids = self.decoder_input_ids[keep_indices]
+        attention_mask = self.attention_mask[keep_indices]
+        if self.decoder_attention_mask is not None:
+            decoder_attention_mask = self.decoder_attention_mask[keep_indices]
+        else:
+            decoder_attention_mask = None
+
+        encoder_last_hidden_state = self.encoder_last_hidden_state[keep_indices]
+
+        past_key_values = [
+            [t[keep_indices] for t in layer] for layer in self.past_key_values
+        ]
+
+        return Seq2SeqLMBatch(
+            batch_id=self.batch_id,
+            requests=requests,
+            requests_idx_mapping=requests_idx_mapping,
+            input_ids=None,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            all_decoder_input_ids=all_decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            encoder_last_hidden_state=encoder_last_hidden_state,
+            past_key_values=past_key_values,
+            input_lengths=input_lengths,
+            decoder_input_lengths=decoder_input_lengths,
+            offsets=offsets,
+            token_offsets=token_offsets,
+            next_token_choosers=next_token_choosers,
+            stopping_criterias=stopping_criterias,
+            max_input_length=max_input_length,
+            max_decoder_input_length=max_decoder_input_length,
+            padding_right_offset=self.padding_right_offset,
+        )
+
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
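The new `filter` method above leans on `requests_idx_mapping`: a dictionary from request id to that request's row in the batch, rebuilt after every filter, so surviving rows can be looked up by id rather than by scanning positional lists. The same bookkeeping, sketched in Rust for illustration only (the diff itself is Python, and this `Batch` struct is a made-up stand-in):

use std::collections::HashMap;

// Made-up stand-in for the batch: one "row" per request plus the id -> row map.
struct Batch {
    rows: Vec<String>,
    requests_idx_mapping: HashMap<u64, usize>,
}

// Keep only the requested ids, looking each one up through the mapping and
// rebuilding the mapping for the smaller batch.
fn filter(batch: &Batch, keep: &[u64]) -> Batch {
    let mut rows = Vec::new();
    let mut requests_idx_mapping = HashMap::new();
    for (new_idx, id) in keep.iter().enumerate() {
        let old_idx = batch.requests_idx_mapping[id];
        rows.push(batch.rows[old_idx].clone());
        requests_idx_mapping.insert(*id, new_idx);
    }
    Batch {
        rows,
        requests_idx_mapping,
    }
}

fn main() {
    let batch = Batch {
        rows: vec!["a".to_string(), "b".to_string(), "c".to_string()],
        requests_idx_mapping: HashMap::from([(10, 0), (11, 1), (12, 2)]),
    };
    let kept = filter(&batch, &[12, 10]);
    assert_eq!(kept.rows, vec!["c".to_string(), "a".to_string()]);
    assert_eq!(kept.requests_idx_mapping[&12], 0);
}

On concatenation the per-batch mappings are merged by adding each batch's cumulative start index, which is what the `requests_idx_mapping[k] = v + start_index` line in the following hunk does.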
@@ -144,7 +235,7 @@ class Seq2SeqLMBatch(Batch):
         max_decoder_input_length = 0
         padding_right_offset = 0
         for batch in batches:
-            total_batch_size += batch.size
+            total_batch_size += len(batch)
             max_input_length = max(max_input_length, batch.max_input_length)
             max_decoder_input_length = max(
                 max_decoder_input_length, batch.max_decoder_input_length
@@ -153,6 +244,8 @@ class Seq2SeqLMBatch(Batch):

         # Batch attributes
         requests = []
+        requests_idx_mapping = {}
+        all_decoder_input_ids = []
         input_lengths = []
         decoder_input_lengths = []
         offsets = []
@@ -174,6 +267,7 @@ class Seq2SeqLMBatch(Batch):
         for i, batch in enumerate(batches):
             # Extend all list attributes
             requests.extend(batch.requests)
+            all_decoder_input_ids.extend(batch.all_decoder_input_ids)
             input_lengths.extend(batch.input_lengths)
             decoder_input_lengths.extend(batch.decoder_input_lengths)
             offsets.extend(batch.offsets)
@@ -181,8 +275,15 @@ class Seq2SeqLMBatch(Batch):
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)

+            if i == 0:
+                requests_idx_mapping = batch.requests_idx_mapping
+            else:
+                # We need to offset the mapping for each batch by the cumulative batch size
+                for k, v in batch.requests_idx_mapping.items():
+                    requests_idx_mapping[k] = v + start_index
+
             # Slicing end index for this batch
-            end_index = start_index + batch.size
+            end_index = start_index + len(batch)

             # We only concatenate batches that did at least one step
             if batch.encoder_last_hidden_state is None:
@@ -201,12 +302,10 @@ class Seq2SeqLMBatch(Batch):
             # Create padded tensor
             if decoder_input_ids is None:
                 decoder_input_ids = batch.decoder_input_ids.new_zeros(
-                    (total_batch_size, max_decoder_input_length),
+                    (total_batch_size, 1),
                 )
             # Copy to correct indices
-            decoder_input_ids[
-                start_index:end_index, -batch.max_decoder_input_length :
-            ] = batch.decoder_input_ids[:, -batch.max_decoder_input_length :]
+            decoder_input_ids[start_index:end_index] = batch.decoder_input_ids

             # Create padded tensor
             if decoder_attention_mask is None:
@@ -302,14 +401,16 @@ class Seq2SeqLMBatch(Batch):
                     start_index:end_index, :, -batch.max_input_length :, :
                 ] = t[:, :, -batch.max_input_length :, :]

-            start_index += batch.size
+            start_index += len(batch)

         return cls(
             batch_id=batches[0].batch_id,
             requests=requests,
+            requests_idx_mapping=requests_idx_mapping,
             input_ids=None,
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
+            all_decoder_input_ids=all_decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
             encoder_last_hidden_state=encoder_last_hidden_state,
             past_key_values=past_key_values,
@@ -319,7 +420,6 @@ class Seq2SeqLMBatch(Batch):
             token_offsets=token_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
-            size=total_batch_size,
             max_input_length=max_input_length,
             max_decoder_input_length=max_decoder_input_length,
             padding_right_offset=padding_right_offset,
@@ -413,46 +513,25 @@ class Seq2SeqLM(Model):
         else:
             decoder_attention_mask = None

-        # check if first forward or not
-        if batch.past_key_values is not None:
-            # Only take the last token
-            decoder_input_ids = batch.decoder_input_ids[:, -1].unsqueeze(-1)
-        else:
-            decoder_input_ids = batch.decoder_input_ids
-
         # Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
         # internally...
         if batch.encoder_last_hidden_state is not None:
             encoder_last_hidden_state = [batch.encoder_last_hidden_state]
         else:
-            encoder_last_hidden_state = batch.encoder_last_hidden_state
+            encoder_last_hidden_state = None

         logits, encoder_last_hidden_state, past = self.forward(
             batch.input_ids,
             batch.attention_mask,
-            decoder_input_ids,
+            batch.decoder_input_ids,
             decoder_attention_mask,
             encoder_last_hidden_state,
             batch.past_key_values,
         )

-        # List of indices to cache
-        next_batch_keep_indices = []
-
-        # New values for next forward
-        next_batch_input_lengths = []
-        next_batch_offsets = []
-        next_batch_token_offsets = []
-        next_batch_decoder_input_ids = []
-        next_batch_decoder_input_lengths = []
-
-        # Metadata
-        next_batch_size = 0
-        next_batch_max_input_length = 0
-        next_batch_max_decoder_input_length = 0
-
         # Finished requests
         generations: List[Generation] = []
+        stopped = True

         # Zipped iterator
         iterator = zip(
@@ -464,7 +543,7 @@ class Seq2SeqLM(Model):
             logits,
             batch.next_token_choosers,
             batch.stopping_criterias,
-            batch.decoder_input_ids,
+            batch.all_decoder_input_ids,
         )

         # For each member of the batch
@@ -477,22 +556,24 @@ class Seq2SeqLM(Model):
             logits,
             next_token_chooser,
             stopping_criteria,
-            decoder_input_ids,
+            all_decoder_input_ids,
         ) in enumerate(iterator):
             # Select next token
             next_token_id, logprobs = next_token_chooser(
-                decoder_input_ids.view(1, -1), logits
+                all_decoder_input_ids.view(1, -1), logits
             )

             # Append next token to decoder tokens
-            decoder_input_ids = torch.cat([decoder_input_ids, next_token_id.squeeze(1)])
+            all_decoder_input_ids = torch.cat(
+                [all_decoder_input_ids, next_token_id.squeeze(1)]
+            )
             new_decoder_input_length = decoder_input_length + 1

             # Generated token
             next_token_logprob = logprobs[-1, next_token_id]
             next_token_id_squeezed = next_token_id.squeeze()
             next_token_text, offset, token_offset = self.decode_token(
-                decoder_input_ids, offset, token_offset
+                all_decoder_input_ids, offset, token_offset
             )

             # Evaluate stopping criteria
@@ -501,7 +582,7 @@ class Seq2SeqLM(Model):
             if stop:
                 # Slice with decoder_input_length to remove padding
                 # Decode all tokens
-                output_text = self.decode(decoder_input_ids[-decoder_input_length:])
+                output_text = self.decode(all_decoder_input_ids[-decoder_input_length:])

                 # Get seed
                 if isinstance(next_token_chooser.choice, Sampling):
@@ -515,19 +596,7 @@ class Seq2SeqLM(Model):
             else:
                 # Keep request in the batch
                 generated_text = None
-                next_batch_keep_indices.append(i)
-                next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0))
-                next_batch_size += 1
-                next_batch_input_lengths.append(input_length)
-                next_batch_decoder_input_lengths.append(new_decoder_input_length)
-                next_batch_offsets.append(offset)
-                next_batch_token_offsets.append(token_offset)
-                next_batch_max_input_length = max(
-                    next_batch_max_input_length, input_length
-                )
-                next_batch_max_decoder_input_length = max(
-                    next_batch_max_decoder_input_length, new_decoder_input_length
-                )
+                stopped = False

             # Prefill
             if stopping_criteria.current_tokens == 1:
@@ -551,69 +620,29 @@ class Seq2SeqLM(Model):

             generations.append(generation)

+            # Update values
+            batch.decoder_input_ids[i] = next_token_id
+            batch.all_decoder_input_ids[i] = all_decoder_input_ids
+            batch.input_lengths[i] = input_length
+            batch.decoder_input_lengths[i] = new_decoder_input_length
+            batch.offsets[i] = offset
+            batch.token_offsets[i] = token_offset
+            batch.max_input_length = max(batch.max_input_length, input_length)
+            batch.max_decoder_input_length = max(
+                batch.max_decoder_input_length, new_decoder_input_length
+            )
+
         # We finished all generations in the batch; there is no next batch
-        if not next_batch_keep_indices:
+        if stopped:
             return generations, None

-        next_batch_decoder_input_ids = torch.cat(next_batch_decoder_input_ids)
-        # If we finished at least one generation, we need to evict the indices of the generations that finished
-        # from the values of the next batch
-        if len(next_batch_keep_indices) != len(batch):
-            # Apply indices to decoder_attention mask, past key values and other items that need to be cached
-            next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
-            if batch.decoder_attention_mask is not None:
-                next_batch_decoder_attention_mask = batch.decoder_attention_mask[
-                    next_batch_keep_indices
-                ]
-            else:
-                next_batch_decoder_attention_mask = None
-
-            next_batch_encoder_last_hidden_state = encoder_last_hidden_state[
-                next_batch_keep_indices
-            ]
-
-            next_batch_past_key_values = [
-                [t[next_batch_keep_indices] for t in layer] for layer in past
-            ]
-            next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
-            next_batch_next_token_choosers = [
-                batch.next_token_choosers[i] for i in next_batch_keep_indices
-            ]
-            next_batch_stopping_criterias = [
-                batch.stopping_criterias[i] for i in next_batch_keep_indices
-            ]
-        else:
-            next_batch_attention_mask = batch.attention_mask
-            next_batch_decoder_attention_mask = batch.decoder_attention_mask
-            next_batch_encoder_last_hidden_state = encoder_last_hidden_state
-            next_batch_past_key_values = past
-
-            next_batch_requests = batch.requests
-            next_batch_next_token_choosers = batch.next_token_choosers
-            next_batch_stopping_criterias = batch.stopping_criterias
+        # We don't need input_ids after the prefill forward
+        batch.input_ids = None
+        batch.encoder_last_hidden_state = encoder_last_hidden_state
+        batch.past_key_values = past

         # Update decoder_attention_mask as we added a new token to input_ids
-        if next_batch_decoder_attention_mask is not None:
-            next_batch_decoder_attention_mask[:, -batch.padding_right_offset] = 1
-
-        next_batch = Seq2SeqLMBatch(
-            batch_id=batch.batch_id,
-            requests=next_batch_requests,
-            input_ids=None,
-            attention_mask=next_batch_attention_mask,
-            decoder_input_ids=next_batch_decoder_input_ids,
-            decoder_attention_mask=next_batch_decoder_attention_mask,
-            encoder_last_hidden_state=next_batch_encoder_last_hidden_state,
-            past_key_values=next_batch_past_key_values,
-            input_lengths=next_batch_input_lengths,
-            decoder_input_lengths=next_batch_decoder_input_lengths,
-            offsets=next_batch_offsets,
-            token_offsets=next_batch_token_offsets,
-            next_token_choosers=next_batch_next_token_choosers,
-            stopping_criterias=next_batch_stopping_criterias,
-            size=next_batch_size,
-            max_input_length=next_batch_max_input_length,
-            max_decoder_input_length=next_batch_max_decoder_input_length,
-            padding_right_offset=batch.padding_right_offset - 1,
-        )
-        return generations, next_batch
+        if batch.decoder_attention_mask is not None:
+            batch.decoder_attention_mask[:, -batch.padding_right_offset] = 1
+        batch.padding_right_offset -= 1
+
+        return generations, batch