working python tests

OlivierDehaene 2023-01-30 12:18:53 +01:00
parent 4a538cfa49
commit 42cdb734a5
9 changed files with 239 additions and 205 deletions

View File

@@ -7,8 +7,8 @@ mod sharded_client;
 pub use client::Client;
 pub use pb::generate::v1::{
-    Batch, GeneratedText, Generation, NextTokenChooserParameters, Request,
-    StoppingCriteriaParameters, PrefillTokens
+    Batch, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens, Request,
+    StoppingCriteriaParameters,
 };
 pub use sharded_client::ShardedClient;
 use thiserror::Error;

View File

@@ -5,7 +5,9 @@ use crate::{Db, Entry, Token};
 use nohash_hasher::IntMap;
 use std::future::Future;
 use std::sync::Arc;
-use text_generation_client::{Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient};
+use text_generation_client::{
+    Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
+};
 use thiserror::Error;
 use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
 use tokio::time::Instant;
@@ -88,7 +90,7 @@ impl Infer {
             input_length,
             time: Instant::now(),
             batch_time: None,
-            _permit: permit
+            _permit: permit,
         });

         // Notify the background task that we have a new entry in the database that needs
@@ -233,7 +235,7 @@ async fn batching_task(

 /// Wrap a future inside a match statement to handle errors and send the responses to Infer
 async fn wrap_future(
-    future: impl Future<Output=Result<(Vec<Generation>, Option<Batch>), ClientError>>,
+    future: impl Future<Output = Result<(Vec<Generation>, Option<Batch>), ClientError>>,
     entries: &mut IntMap<u64, Entry>,
 ) -> Option<Batch> {
     match future.await {

View File

@@ -91,9 +91,9 @@ def test_causal_lm_batch_type(default_bloom):
 def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
     sequence_length = len(default_bloom_batch.all_input_ids[0])
-    generated_texts, next_batch = default_bloom.generate_token(default_bloom_batch)
+    generations, next_batch = default_bloom.generate_token(default_bloom_batch)

-    assert generated_texts == []
+    assert len(generations) == len(default_bloom_batch)
     assert isinstance(next_batch, CausalLMBatch)
     assert not next_batch.keys_head_dim_last
@@ -122,24 +122,30 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
     assert all(
         [p[1].shape == (16, sequence_length, 64) for p in next_batch.past_key_values]
     )
+    assert all([generation.generated_text is None for generation in generations])
+    assert all([len(generation.prefill_tokens) == 1 for generation in generations])
+    assert all([generation.token_id.item() == 10264 for generation in generations])
+    assert all([generation.token_text == "Test" for generation in generations])
+    assert generations[0].request_id == 0


 def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch):
     next_batch = default_bloom_batch
     for _ in range(default_bloom_batch.stopping_criterias[0].max_new_tokens - 1):
-        generated_texts, next_batch = default_bloom.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_bloom.generate_token(next_batch)
+        assert len(generations) == len(default_bloom_batch)

-    generated_texts, next_batch = default_bloom.generate_token(next_batch)
+    generations, next_batch = default_bloom.generate_token(next_batch)
     assert next_batch is None

-    assert len(generated_texts) == 1
-    assert (
-        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
-    )
-    assert generated_texts[0].request == default_bloom_batch.requests[0]
+    assert len(generations) == 1
+    assert (
+        generations[0].generated_text.text
+        == "TestTestTestTestTestTestTestTestTestTestTest"
+    )
+    assert generations[0].request_id == default_bloom_batch.requests[0].id
     assert (
-        generated_texts[0].generated_tokens
+        generations[0].generated_text.generated_tokens
         == default_bloom_batch.stopping_criterias[0].max_new_tokens
     )
@@ -152,17 +158,19 @@ def test_causal_lm_generate_token_completion_multi(
     for i in range(
         default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 1
     ):
-        generated_texts, next_batch = default_bloom.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_bloom.generate_token(next_batch)
+        assert len(generations) == len(default_multi_requests_bloom_batch)

-    generated_texts, next_batch = default_bloom.generate_token(next_batch)
+    generations, next_batch = default_bloom.generate_token(next_batch)
     assert next_batch is not None

-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "TestTestTestTestTestTest"
-    assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1]
+    assert len(generations) == 2
+    assert generations[1].generated_text.text == "TestTestTestTestTestTest"
+    assert (
+        generations[1].request_id == default_multi_requests_bloom_batch.requests[1].id
+    )
     assert (
-        generated_texts[0].generated_tokens
+        generations[1].generated_text.generated_tokens
         == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
     )
@@ -171,19 +179,22 @@ def test_causal_lm_generate_token_completion_multi(
         - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
         - 1
     ):
-        generated_texts, next_batch = default_bloom.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_bloom.generate_token(next_batch)
+        assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_bloom.generate_token(next_batch)
+    generations, next_batch = default_bloom.generate_token(next_batch)
     assert next_batch is None

-    assert len(generated_texts) == 1
-    assert (
-        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
-    )
-    assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
+    assert len(generations) == 1
+    assert (
+        generations[0].generated_text.text
+        == "TestTestTestTestTestTestTestTestTestTestTest"
+    )
+    assert (
+        generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
+    )
     assert (
-        generated_texts[0].generated_tokens
+        generations[0].generated_text.generated_tokens
         == default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
     )
@@ -243,17 +254,19 @@ def test_batch_concatenate(
     for _ in range(
         default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 2
     ):
-        generated_texts, next_batch = default_bloom.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_bloom.generate_token(next_batch)
+        assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_bloom.generate_token(next_batch)
+    generations, next_batch = default_bloom.generate_token(next_batch)
     assert next_batch is not None

-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "TestTestTestTestTestTest"
-    assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1]
+    assert len(generations) == 3
+    assert generations[2].generated_text.text == "TestTestTestTestTestTest"
+    assert (
+        generations[2].request_id == default_multi_requests_bloom_batch.requests[1].id
+    )
     assert (
-        generated_texts[0].generated_tokens
+        generations[2].generated_text.generated_tokens
         == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
     )
@@ -262,19 +275,20 @@ def test_batch_concatenate(
         - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
         - 2
     ):
-        generated_texts, next_batch = default_bloom.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_bloom.generate_token(next_batch)
+        assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_bloom.generate_token(next_batch)
+    generations, next_batch = default_bloom.generate_token(next_batch)
     assert next_batch is not None

-    assert len(generated_texts) == 1
-    assert (
-        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
-    )
-    assert generated_texts[0].request == default_bloom_batch.requests[0]
+    assert len(generations) == 2
+    assert (
+        generations[0].generated_text.text
+        == "TestTestTestTestTestTestTestTestTestTestTest"
+    )
+    assert generations[0].request_id == default_bloom_batch.requests[0].id
     assert (
-        generated_texts[0].generated_tokens
+        generations[0].generated_text.generated_tokens
         == default_bloom_batch.stopping_criterias[0].max_new_tokens
     )
@@ -284,18 +298,21 @@ def test_batch_concatenate(
         - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
         - 4
     ):
-        generated_texts, next_batch = default_bloom.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_bloom.generate_token(next_batch)
+        assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_bloom.generate_token(next_batch)
+    generations, next_batch = default_bloom.generate_token(next_batch)
     assert next_batch is None

-    assert len(generated_texts) == 1
-    assert (
-        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
-    )
-    assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
+    assert len(generations) == 1
+    assert (
+        generations[0].generated_text.text
+        == "TestTestTestTestTestTestTestTestTestTestTest"
+    )
+    assert (
+        generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
+    )
     assert (
-        generated_texts[0].generated_tokens
+        generations[0].generated_text.generated_tokens
         == default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
     )

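The assertions above encode the new contract of generate_token: instead of returning only the finished GeneratedText objects, it now returns one Generation per request still in the batch, and a request's generated_text stays None until its stopping criteria fire. Below is a minimal sketch of a consumer of that contract, using only the fields exercised by these tests; drive_to_completion is a hypothetical helper written for illustration, not part of the commit.

def drive_to_completion(model, batch):
    # Each call yields one Generation per remaining request, plus the next
    # batch (None once every request has finished).
    finished = {}
    while batch is not None:
        generations, batch = model.generate_token(batch)
        for generation in generations:
            # generated_text is None until this request hits its stopping criteria
            if generation.generated_text is not None:
                finished[generation.request_id] = generation.generated_text.text
    return finished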
View File

@@ -88,11 +88,9 @@ def test_causal_lm_batch_type(default_causal_lm):
 def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
     sequence_length = len(default_causal_lm_batch.all_input_ids[0])
-    generated_texts, next_batch = default_causal_lm.generate_token(
-        default_causal_lm_batch
-    )
+    generations, next_batch = default_causal_lm.generate_token(default_causal_lm_batch)

-    assert generated_texts == []
+    assert len(generations) == len(next_batch)
     assert isinstance(next_batch, CausalLMBatch)

     assert len(next_batch.all_input_ids) == next_batch.size
@@ -121,6 +119,11 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
     assert all(
         [p[1].shape == (1, 12, sequence_length, 64) for p in next_batch.past_key_values]
     )
+    assert all([generation.generated_text is None for generation in generations])
+    assert all([len(generation.prefill_tokens) == 1 for generation in generations])
+    assert all([generation.token_id.item() == 13 for generation in generations])
+    assert all([generation.token_text == "." for generation in generations])
+    assert generations[0].request_id == 0


 def test_causal_lm_generate_token_completion(
@@ -128,18 +131,17 @@ def test_causal_lm_generate_token_completion(
 ):
     next_batch = default_causal_lm_batch
     for _ in range(default_causal_lm_batch.stopping_criterias[0].max_new_tokens - 1):
-        generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_causal_lm.generate_token(next_batch)
+        assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
+    generations, next_batch = default_causal_lm.generate_token(next_batch)
     assert next_batch is None

-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "Test.java:784) at net.minecraft."
-    assert generated_texts[0].request == default_causal_lm_batch.requests[0]
-    assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs)
+    assert len(generations) == 1
+    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].request_id == default_causal_lm_batch.requests[0].id
     assert (
-        generated_texts[0].generated_tokens
+        generations[0].generated_text.generated_tokens
         == default_causal_lm_batch.stopping_criterias[0].max_new_tokens
     )
@@ -152,19 +154,20 @@ def test_causal_lm_generate_token_completion_multi(
     for i in range(
         default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 1
     ):
-        generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_causal_lm.generate_token(next_batch)
+        assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
+    generations, next_batch = default_causal_lm.generate_token(next_batch)
     assert next_batch is not None

-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "Test.java:784)"
+    assert len(generations) == 2
+    assert generations[1].generated_text.text == "Test.java:784)"
     assert (
-        generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1]
+        generations[1].request_id
+        == default_multi_requests_causal_lm_batch.requests[1].id
     )
     assert (
-        generated_texts[0].generated_tokens
+        generations[1].generated_text.generated_tokens
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )
@@ -173,19 +176,20 @@ def test_causal_lm_generate_token_completion_multi(
         - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
         - 1
     ):
-        generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_causal_lm.generate_token(next_batch)
+        assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
+    generations, next_batch = default_causal_lm.generate_token(next_batch)
     assert next_batch is None

-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "Test.java:784) at net.minecraft."
+    assert len(generations) == 1
+    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
     assert (
-        generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0]
+        generations[0].request_id
+        == default_multi_requests_causal_lm_batch.requests[0].id
     )
     assert (
-        generated_texts[0].generated_tokens
+        generations[0].generated_text.generated_tokens
         == default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
     )
@@ -244,19 +248,20 @@ def test_batch_concatenate(
     for _ in range(
         default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 2
     ):
-        generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_causal_lm.generate_token(next_batch)
+        assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
+    generations, next_batch = default_causal_lm.generate_token(next_batch)
     assert next_batch is not None

-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "Test.java:784)"
+    assert len(generations) == 3
+    assert generations[2].generated_text.text == "Test.java:784)"
     assert (
-        generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1]
+        generations[2].request_id
+        == default_multi_requests_causal_lm_batch.requests[1].id
     )
     assert (
-        generated_texts[0].generated_tokens
+        generations[2].generated_text.generated_tokens
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )
@@ -265,17 +270,17 @@ def test_batch_concatenate(
         - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
         - 2
     ):
-        generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_causal_lm.generate_token(next_batch)
+        assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
+    generations, next_batch = default_causal_lm.generate_token(next_batch)
     assert next_batch is not None

-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "Test.java:784) at net.minecraft."
-    assert generated_texts[0].request == default_causal_lm_batch.requests[0]
+    assert len(generations) == 2
+    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].request_id == default_causal_lm_batch.requests[0].id
     assert (
-        generated_texts[0].generated_tokens
+        generations[0].generated_text.generated_tokens
         == default_causal_lm_batch.stopping_criterias[0].max_new_tokens
     )
@@ -285,18 +290,19 @@ def test_batch_concatenate(
         - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
         - 4
     ):
-        generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_causal_lm.generate_token(next_batch)
+        assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
+    generations, next_batch = default_causal_lm.generate_token(next_batch)
     assert next_batch is None

-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "Test.java:784) at net.minecraft."
+    assert len(generations) == 1
+    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
     assert (
-        generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0]
+        generations[0].request_id
+        == default_multi_requests_causal_lm_batch.requests[0].id
     )
     assert (
-        generated_texts[0].generated_tokens
+        generations[0].generated_text.generated_tokens
         == default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
     )

View File

@@ -50,18 +50,17 @@ def test_santacoder_generate_token_completion(default_santacoder, default_pb_bat
     next_batch = batch

     for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
-        generated_texts, next_batch = default_santacoder.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_santacoder.generate_token(next_batch)
+        assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_santacoder.generate_token(next_batch)
+    generations, next_batch = default_santacoder.generate_token(next_batch)
     assert next_batch is None

-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "def test_get_all_users_with_"
-    assert generated_texts[0].request == batch.requests[0]
-    assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs)
+    assert len(generations) == 1
+    assert generations[0].generated_text.text == "def test_get_all_users_with_"
+    assert generations[0].request_id == batch.requests[0].id
     assert (
-        generated_texts[0].generated_tokens
+        generations[0].generated_text.generated_tokens
         == batch.stopping_criterias[0].max_new_tokens
     )
@@ -76,20 +75,19 @@ def test_fim_santacoder_generate_token_completion(
     next_batch = batch

     for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
-        generated_texts, next_batch = default_santacoder.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_santacoder.generate_token(next_batch)
+        assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_santacoder.generate_token(next_batch)
+    generations, next_batch = default_santacoder.generate_token(next_batch)
     assert next_batch is None

-    assert len(generated_texts) == 1
+    assert len(generations) == 1
     assert (
-        generated_texts[0].output_text
+        generations[0].generated_text.text
         == """<fim-prefix>def<fim-suffix>world<fim-middle>ineProperty(exports, "__esModule", { value"""
     )
-    assert generated_texts[0].request == batch.requests[0]
-    assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs)
+    assert generations[0].request_id == batch.requests[0].id
     assert (
-        generated_texts[0].generated_tokens
+        generations[0].generated_text.generated_tokens
         == batch.stopping_criterias[0].max_new_tokens
     )

View File

@@ -99,11 +99,11 @@ def test_seq2seq_lm_batch_type(default_seq2seq_lm):
 def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch):
     sequence_length = len(default_seq2seq_lm_batch.input_ids[0])
-    generated_texts, next_batch = default_seq2seq_lm.generate_token(
+    generations, next_batch = default_seq2seq_lm.generate_token(
         default_seq2seq_lm_batch
     )

-    assert generated_texts == []
+    assert len(generations) == len(next_batch)
     assert isinstance(next_batch, Seq2SeqLMBatch)

     assert torch.equal(next_batch.input_ids, default_seq2seq_lm_batch.input_ids)
@@ -145,6 +145,11 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
             for p in next_batch.past_key_values
         ]
     )
+    assert all([generation.generated_text is None for generation in generations])
+    assert all([len(generation.prefill_tokens) == 1 for generation in generations])
+    assert all([generation.token_id.item() == 259 for generation in generations])
+    assert all([generation.token_text == "" for generation in generations])
+    assert generations[0].request_id == 0


 def test_seq2seq_lm_generate_token_completion(
@@ -152,16 +157,16 @@ def test_seq2seq_lm_generate_token_completion(
 ):
     next_batch = default_seq2seq_lm_batch
     for _ in range(6):
-        generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+        assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is None

-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "a few weeks"
-    assert generated_texts[0].request == default_seq2seq_lm_batch.requests[0]
-    assert generated_texts[0].generated_tokens == 7
+    assert len(generations) == 1
+    assert generations[0].generated_text.text == "a few weeks"
+    assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
+    assert generations[0].generated_text.generated_tokens == 7


 def test_seq2seq_lm_generate_token_completion_multi(
@@ -170,33 +175,33 @@ def test_seq2seq_lm_generate_token_completion_multi(
     next_batch = default_multi_requests_seq2seq_lm_batch

     for i in range(4):
-        generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+        assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None

-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "a few "
+    assert len(generations) == 2
+    assert generations[1].generated_text.text == "a few "
     assert (
-        generated_texts[0].request
-        == default_multi_requests_seq2seq_lm_batch.requests[1]
+        generations[1].request_id
+        == default_multi_requests_seq2seq_lm_batch.requests[1].id
     )
-    assert generated_texts[0].generated_tokens == 5
+    assert generations[1].generated_text.generated_tokens == 5

-    generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
-    assert generated_texts == []
+    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is None

-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "a few weeks"
+    assert len(generations) == 1
+    assert generations[0].generated_text.text == "a few weeks"
     assert (
-        generated_texts[0].request
-        == default_multi_requests_seq2seq_lm_batch.requests[0]
+        generations[0].request_id
+        == default_multi_requests_seq2seq_lm_batch.requests[0].id
     )
-    assert generated_texts[0].generated_tokens == 7
+    assert generations[0].generated_text.generated_tokens == 7


 def test_batch_concatenate(
@@ -291,35 +296,35 @@ def test_batch_concatenate(
     )

     for _ in range(3):
-        generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+        assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None

-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "a few "
+    assert len(generations) == 3
+    assert generations[2].generated_text.text == "a few "
     assert (
-        generated_texts[0].request
-        == default_multi_requests_seq2seq_lm_batch.requests[1]
+        generations[2].request_id
+        == default_multi_requests_seq2seq_lm_batch.requests[1].id
     )
-    assert generated_texts[0].generated_tokens == 5
+    assert generations[2].generated_text.generated_tokens == 5

-    generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None

-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "a few weeks"
-    assert generated_texts[0].request == default_seq2seq_lm_batch.requests[0]
-    assert generated_texts[0].generated_tokens == 7
+    assert len(generations) == 2
+    assert generations[0].generated_text.text == "a few weeks"
+    assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
+    assert generations[0].generated_text.generated_tokens == 7

-    generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is None

-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "a few weeks"
+    assert len(generations) == 1
+    assert generations[0].generated_text.text == "a few weeks"
     assert (
-        generated_texts[0].request
-        == default_multi_requests_seq2seq_lm_batch.requests[0]
+        generations[0].request_id
+        == default_multi_requests_seq2seq_lm_batch.requests[0].id
     )
-    assert generated_texts[0].generated_tokens == 7
+    assert generations[0].generated_text.generated_tokens == 7

View File

@@ -343,9 +343,11 @@ class CausalLM(Model):
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id]
             next_token_id_squeezed = next_token_id.squeeze()
-            next_token_text = self.tokenizer.decode(next_token_id_squeezed,
-                                                    clean_up_tokenization_spaces=False,
-                                                    skip_special_tokens=False)
+            next_token_text = self.tokenizer.decode(
+                next_token_id_squeezed,
+                clean_up_tokenization_spaces=False,
+                skip_special_tokens=False,
+            )

             # Evaluate stopping criteria
             stop, reason = stopping_criteria(
@@ -385,7 +387,9 @@ class CausalLM(Model):
             # Prefill
             if stopping_criteria.current_tokens == 1:
                 # Remove generated token to only have prefill and add nan for first prompt token
-                prefill_logprobs = [float("nan")] + logprobs.gather(1, all_input_ids[1:]).squeeze(1)[-new_input_length:-1].tolist()
+                prefill_logprobs = [float("nan")] + logprobs.gather(
+                    1, all_input_ids[1:]
+                ).squeeze(1)[-new_input_length:-1].tolist()
                 prefill_token_ids = all_input_ids[-new_input_length:-1]
                 prefill_texts = self.tokenizer.batch_decode(
                     prefill_token_ids,

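The reformatted prefill block above gathers, for every prompt position, the log-probability the model assigned to the token that actually followed it, then prepends NaN because the first prompt token has no preceding context. Below is a toy shape-check of that gather pattern with made-up values; it illustrates the indexing idea, not the model's exact slicing.

import torch

logprobs = torch.log_softmax(torch.randn(4, 10), dim=-1)  # (positions, vocab)
next_ids = torch.tensor([[3], [7], [1], [4]])             # observed token ids, shape (4, 1)

# Pick the log-prob of the observed token at each position.
picked = logprobs.gather(1, next_ids).squeeze(1)          # shape (4,)

# Drop the final (generated) position and prepend NaN for the first prompt token,
# mirroring the [float("nan")] + ... pattern in the diff.
prefill_logprobs = [float("nan")] + picked[:-1].tolist()
assert len(prefill_logprobs) == 4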
View File

@@ -50,10 +50,10 @@ class Seq2SeqLMBatch(Batch):
     @classmethod
     def from_pb(
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
         device: torch.device,
     ) -> "Seq2SeqLMBatch":
         """Convert a text_generation.v1.Batch protobuf to a Seq2SeqLMBatch"""
         inputs = []
@@ -158,8 +158,8 @@ class Seq2SeqLMBatch(Batch):
             )
             # Copy to correct indices
             input_ids[
-                start_index:end_index, -batch.max_input_length:
-            ] = batch.input_ids[:, -batch.max_input_length:]
+                start_index:end_index, -batch.max_input_length :
+            ] = batch.input_ids[:, -batch.max_input_length :]

             # Create padded tensor
             if attention_mask is None:
@@ -168,8 +168,8 @@ class Seq2SeqLMBatch(Batch):
             )
             # Copy to correct indices
             attention_mask[
-                start_index:end_index, -batch.max_input_length:
-            ] = batch.attention_mask[:, -batch.max_input_length:]
+                start_index:end_index, -batch.max_input_length :
+            ] = batch.attention_mask[:, -batch.max_input_length :]

             # Create padded tensor
             if decoder_input_ids is None:
@@ -178,8 +178,8 @@ class Seq2SeqLMBatch(Batch):
             )
             # Copy to correct indices
             decoder_input_ids[
-                start_index:end_index, -batch.max_decoder_input_length:
-            ] = batch.decoder_input_ids[:, -batch.max_decoder_input_length:]
+                start_index:end_index, -batch.max_decoder_input_length :
+            ] = batch.decoder_input_ids[:, -batch.max_decoder_input_length :]

             # Create padded tensor
             if decoder_attention_mask is None:
@@ -191,13 +191,13 @@ class Seq2SeqLMBatch(Batch):
             # this batch. All generations are of length `batch.max_decoder_input_length`.
             if batch.decoder_attention_mask is None:
                 decoder_attention_mask[
-                    start_index:end_index, -batch.max_decoder_input_length:
+                    start_index:end_index, -batch.max_decoder_input_length :
                 ] = 1
             # If it exists, we need to index
             else:
                 decoder_attention_mask[
-                    start_index:end_index, -batch.max_decoder_input_length:
-                ] = batch.decoder_attention_mask[:, -batch.max_decoder_input_length:]
+                    start_index:end_index, -batch.max_decoder_input_length :
+                ] = batch.decoder_attention_mask[:, -batch.max_decoder_input_length :]

             # Create padded tensor
             if encoder_last_hidden_state is None:
@@ -211,8 +211,8 @@ class Seq2SeqLMBatch(Batch):
             # Copy to correct indices
             encoder_last_hidden_state[
-                start_index:end_index, -batch.max_input_length:, :
-            ] = batch.encoder_last_hidden_state[:, -batch.max_input_length:, :]
+                start_index:end_index, -batch.max_input_length :, :
+            ] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :]

             # Iterate over attention layers
             for j, past in enumerate(batch.past_key_values):
@@ -238,11 +238,11 @@ class Seq2SeqLMBatch(Batch):
                 # We slice the past keys and values to remove the padding from previous batches
                 past_key_values[j][k][
                     start_index:end_index,
                     :,
-                    -(batch.max_decoder_input_length - 1):,
+                    -(batch.max_decoder_input_length - 1) :,
                     :,
-                ] = t[:, :, -(batch.max_decoder_input_length - 1):, :]
+                ] = t[:, :, -(batch.max_decoder_input_length - 1) :, :]

             # encoder past
             for k, t in enumerate(past[2:]):
@@ -261,8 +261,8 @@ class Seq2SeqLMBatch(Batch):
                     past_key_values[j].append(t.new_zeros(padded_t_shape))

                 past_key_values[j][idx][
-                    start_index:end_index, :, -batch.max_input_length:, :
-                ] = t[:, :, -batch.max_input_length:, :]
+                    start_index:end_index, :, -batch.max_input_length :, :
+                ] = t[:, :, -batch.max_input_length :, :]

             start_index += batch.size
@@ -322,13 +322,13 @@ class Seq2SeqLM(Model):
         return self.tokenizer.decode(decoder_ids, skip_special_tokens=True)

     def forward(
         self,
         input_ids,
         attention_mask,
         decoder_input_ids,
         decoder_attention_mask: Optional,
         encoder_last_hidden_state: Optional,
         past_key_values: Optional = None,
     ) -> Tuple[
         torch.Tensor,
         torch.Tensor,
@@ -359,7 +359,7 @@ class Seq2SeqLM(Model):
         )

     def generate_token(
         self, batch: Seq2SeqLMBatch
     ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
         # For some reason, inference_mode does not work well with GLOO which we use on CPU
         context_manager = (
@@ -405,14 +405,14 @@ class Seq2SeqLM(Model):
         # For each member of the batch
         for i, (
             request,
             input_length,
             decoder_input_length,
             logits,
             next_token_chooser,
             stopping_criteria,
             input_tokens,
             decoder_input_ids,
         ) in enumerate(iterator):
             # Select next token
             next_token_id, logprobs = next_token_chooser(decoder_input_ids, logits)
@@ -424,15 +424,14 @@ class Seq2SeqLM(Model):
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id]
             next_token_id_squeezed = next_token_id.squeeze()
-            next_token_text = self.tokenizer.decode(next_token_id_squeezed,
-                                                    clean_up_tokenization_spaces=False,
-                                                    skip_special_tokens=False)
+            next_token_text = self.tokenizer.decode(
+                next_token_id_squeezed,
+                clean_up_tokenization_spaces=False,
+                skip_special_tokens=False,
+            )

             # Evaluate stopping criteria
-            stop, reason = stopping_criteria(
-                next_token_id,
-                next_token_text
-            )
+            stop, reason = stopping_criteria(next_token_id, next_token_text)

             if stop:
                 # Slice with decoder_input_length to remove padding

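Both models now decode the freshly sampled token with the same tokenizer call before handing it to the stopping criteria. Below is a small sketch of that call, assuming a Hugging Face tokenizer; "t5-small" is only an example checkpoint, not necessarily the one these tests load.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
next_token_id = 259  # the id asserted in test_seq2seq_lm_generate_token above

# Keep special tokens and raw spacing so streamed pieces concatenate cleanly.
next_token_text = tokenizer.decode(
    next_token_id,
    clean_up_tokenization_spaces=False,
    skip_special_tokens=False,
)
print(repr(next_token_text))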
View File

@@ -61,6 +61,9 @@ class PrefillTokens:
             ids=self.token_ids, logprobs=self.logprobs, texts=self.texts
         )

+    def __len__(self):
+        return len(self.token_ids)
+

 @dataclass
 class Generation:
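The added __len__ is what lets the tests write len(generation.prefill_tokens) == 1. Below is a hedged sketch of the shapes those tests rely on; the field lists are inferred from the assertions in this commit, not copied from the repository's types module.

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class PrefillTokens:
    token_ids: List[int]
    logprobs: List[float]
    texts: List[str]

    def __len__(self):
        # Added in this commit so tests can check the number of prefill tokens
        return len(self.token_ids)


@dataclass
class GeneratedText:
    text: str
    generated_tokens: int


@dataclass
class Generation:
    request_id: int
    prefill_tokens: Optional[PrefillTokens]
    token_id: int  # a tensor in the real code (the tests call .item() on it)
    token_text: str
    generated_text: Optional[GeneratedText]  # None until the request finishes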