working python tests

This commit is contained in:
OlivierDehaene 2023-01-30 12:18:53 +01:00
parent b2a468176d
commit adf80bc23d
9 changed files with 239 additions and 205 deletions

View File

@ -7,8 +7,8 @@ mod sharded_client;
pub use client::Client;
pub use pb::generate::v1::{
Batch, GeneratedText, Generation, NextTokenChooserParameters, Request,
StoppingCriteriaParameters, PrefillTokens
Batch, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens, Request,
StoppingCriteriaParameters,
};
pub use sharded_client::ShardedClient;
use thiserror::Error;

View File

@ -5,7 +5,9 @@ use crate::{Db, Entry, Token};
use nohash_hasher::IntMap;
use std::future::Future;
use std::sync::Arc;
use text_generation_client::{Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient};
use text_generation_client::{
Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
};
use thiserror::Error;
use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
use tokio::time::Instant;
@ -88,7 +90,7 @@ impl Infer {
input_length,
time: Instant::now(),
batch_time: None,
_permit: permit
_permit: permit,
});
// Notify the background task that we have a new entry in the database that needs
@ -233,7 +235,7 @@ async fn batching_task(
/// Wrap a future inside a match statement to handle errors and send the responses to Infer
async fn wrap_future(
future: impl Future<Output=Result<(Vec<Generation>, Option<Batch>), ClientError>>,
future: impl Future<Output = Result<(Vec<Generation>, Option<Batch>), ClientError>>,
entries: &mut IntMap<u64, Entry>,
) -> Option<Batch> {
match future.await {

View File

@ -91,9 +91,9 @@ def test_causal_lm_batch_type(default_bloom):
def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
sequence_length = len(default_bloom_batch.all_input_ids[0])
generated_texts, next_batch = default_bloom.generate_token(default_bloom_batch)
generations, next_batch = default_bloom.generate_token(default_bloom_batch)
assert generated_texts == []
assert len(generations) == len(default_bloom_batch)
assert isinstance(next_batch, CausalLMBatch)
assert not next_batch.keys_head_dim_last
@ -122,24 +122,30 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
assert all(
[p[1].shape == (16, sequence_length, 64) for p in next_batch.past_key_values]
)
assert all([generation.generated_text is None for generation in generations])
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
assert all([generation.token_id.item() == 10264 for generation in generations])
assert all([generation.token_text == "Test" for generation in generations])
assert generations[0].request_id == 0
def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch):
next_batch = default_bloom_batch
for _ in range(default_bloom_batch.stopping_criterias[0].max_new_tokens - 1):
generated_texts, next_batch = default_bloom.generate_token(next_batch)
assert generated_texts == []
generations, next_batch = default_bloom.generate_token(next_batch)
assert len(generations) == len(default_bloom_batch)
generated_texts, next_batch = default_bloom.generate_token(next_batch)
generations, next_batch = default_bloom.generate_token(next_batch)
assert next_batch is None
assert len(generated_texts) == 1
assert len(generations) == 1
assert (
generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
generations[0].generated_text.text
== "TestTestTestTestTestTestTestTestTestTestTest"
)
assert generated_texts[0].request == default_bloom_batch.requests[0]
assert generations[0].request_id == default_bloom_batch.requests[0].id
assert (
generated_texts[0].generated_tokens
generations[0].generated_text.generated_tokens
== default_bloom_batch.stopping_criterias[0].max_new_tokens
)
@ -152,17 +158,19 @@ def test_causal_lm_generate_token_completion_multi(
for i in range(
default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 1
):
generated_texts, next_batch = default_bloom.generate_token(next_batch)
assert generated_texts == []
generations, next_batch = default_bloom.generate_token(next_batch)
assert len(generations) == len(default_multi_requests_bloom_batch)
generated_texts, next_batch = default_bloom.generate_token(next_batch)
generations, next_batch = default_bloom.generate_token(next_batch)
assert next_batch is not None
assert len(generated_texts) == 1
assert generated_texts[0].output_text == "TestTestTestTestTestTest"
assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1]
assert len(generations) == 2
assert generations[1].generated_text.text == "TestTestTestTestTestTest"
assert (
generated_texts[0].generated_tokens
generations[1].request_id == default_multi_requests_bloom_batch.requests[1].id
)
assert (
generations[1].generated_text.generated_tokens
== default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
)
@ -171,19 +179,22 @@ def test_causal_lm_generate_token_completion_multi(
- default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
- 1
):
generated_texts, next_batch = default_bloom.generate_token(next_batch)
assert generated_texts == []
generations, next_batch = default_bloom.generate_token(next_batch)
assert len(generations) == len(next_batch)
generated_texts, next_batch = default_bloom.generate_token(next_batch)
generations, next_batch = default_bloom.generate_token(next_batch)
assert next_batch is None
assert len(generated_texts) == 1
assert len(generations) == 1
assert (
generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
generations[0].generated_text.text
== "TestTestTestTestTestTestTestTestTestTestTest"
)
assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
assert (
generated_texts[0].generated_tokens
generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
)
assert (
generations[0].generated_text.generated_tokens
== default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
)
@ -243,17 +254,19 @@ def test_batch_concatenate(
for _ in range(
default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 2
):
generated_texts, next_batch = default_bloom.generate_token(next_batch)
assert generated_texts == []
generations, next_batch = default_bloom.generate_token(next_batch)
assert len(generations) == len(next_batch)
generated_texts, next_batch = default_bloom.generate_token(next_batch)
generations, next_batch = default_bloom.generate_token(next_batch)
assert next_batch is not None
assert len(generated_texts) == 1
assert generated_texts[0].output_text == "TestTestTestTestTestTest"
assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1]
assert len(generations) == 3
assert generations[2].generated_text.text == "TestTestTestTestTestTest"
assert (
generated_texts[0].generated_tokens
generations[2].request_id == default_multi_requests_bloom_batch.requests[1].id
)
assert (
generations[2].generated_text.generated_tokens
== default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
)
@ -262,19 +275,20 @@ def test_batch_concatenate(
- default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
- 2
):
generated_texts, next_batch = default_bloom.generate_token(next_batch)
assert generated_texts == []
generations, next_batch = default_bloom.generate_token(next_batch)
assert len(generations) == len(next_batch)
generated_texts, next_batch = default_bloom.generate_token(next_batch)
generations, next_batch = default_bloom.generate_token(next_batch)
assert next_batch is not None
assert len(generated_texts) == 1
assert len(generations) == 2
assert (
generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
generations[0].generated_text.text
== "TestTestTestTestTestTestTestTestTestTestTest"
)
assert generated_texts[0].request == default_bloom_batch.requests[0]
assert generations[0].request_id == default_bloom_batch.requests[0].id
assert (
generated_texts[0].generated_tokens
generations[0].generated_text.generated_tokens
== default_bloom_batch.stopping_criterias[0].max_new_tokens
)
@ -284,18 +298,21 @@ def test_batch_concatenate(
- default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
- 4
):
generated_texts, next_batch = default_bloom.generate_token(next_batch)
assert generated_texts == []
generations, next_batch = default_bloom.generate_token(next_batch)
assert len(generations) == len(next_batch)
generated_texts, next_batch = default_bloom.generate_token(next_batch)
generations, next_batch = default_bloom.generate_token(next_batch)
assert next_batch is None
assert len(generated_texts) == 1
assert len(generations) == 1
assert (
generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
generations[0].generated_text.text
== "TestTestTestTestTestTestTestTestTestTestTest"
)
assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
assert (
generated_texts[0].generated_tokens
generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
)
assert (
generations[0].generated_text.generated_tokens
== default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
)

View File

@ -88,11 +88,9 @@ def test_causal_lm_batch_type(default_causal_lm):
def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
sequence_length = len(default_causal_lm_batch.all_input_ids[0])
generated_texts, next_batch = default_causal_lm.generate_token(
default_causal_lm_batch
)
generations, next_batch = default_causal_lm.generate_token(default_causal_lm_batch)
assert generated_texts == []
assert len(generations) == len(next_batch)
assert isinstance(next_batch, CausalLMBatch)
assert len(next_batch.all_input_ids) == next_batch.size
@ -121,6 +119,11 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
assert all(
[p[1].shape == (1, 12, sequence_length, 64) for p in next_batch.past_key_values]
)
assert all([generation.generated_text is None for generation in generations])
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
assert all([generation.token_id.item() == 13 for generation in generations])
assert all([generation.token_text == "." for generation in generations])
assert generations[0].request_id == 0
def test_causal_lm_generate_token_completion(
@ -128,18 +131,17 @@ def test_causal_lm_generate_token_completion(
):
next_batch = default_causal_lm_batch
for _ in range(default_causal_lm_batch.stopping_criterias[0].max_new_tokens - 1):
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
assert generated_texts == []
generations, next_batch = default_causal_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
generations, next_batch = default_causal_lm.generate_token(next_batch)
assert next_batch is None
assert len(generated_texts) == 1
assert generated_texts[0].output_text == "Test.java:784) at net.minecraft."
assert generated_texts[0].request == default_causal_lm_batch.requests[0]
assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs)
assert len(generations) == 1
assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
assert generations[0].request_id == default_causal_lm_batch.requests[0].id
assert (
generated_texts[0].generated_tokens
generations[0].generated_text.generated_tokens
== default_causal_lm_batch.stopping_criterias[0].max_new_tokens
)
@ -152,19 +154,20 @@ def test_causal_lm_generate_token_completion_multi(
for i in range(
default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 1
):
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
assert generated_texts == []
generations, next_batch = default_causal_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
generations, next_batch = default_causal_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generated_texts) == 1
assert generated_texts[0].output_text == "Test.java:784)"
assert len(generations) == 2
assert generations[1].generated_text.text == "Test.java:784)"
assert (
generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1]
generations[1].request_id
== default_multi_requests_causal_lm_batch.requests[1].id
)
assert (
generated_texts[0].generated_tokens
generations[1].generated_text.generated_tokens
== default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
)
@ -173,19 +176,20 @@ def test_causal_lm_generate_token_completion_multi(
- default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
- 1
):
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
assert generated_texts == []
generations, next_batch = default_causal_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
generations, next_batch = default_causal_lm.generate_token(next_batch)
assert next_batch is None
assert len(generated_texts) == 1
assert generated_texts[0].output_text == "Test.java:784) at net.minecraft."
assert len(generations) == 1
assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
assert (
generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0]
generations[0].request_id
== default_multi_requests_causal_lm_batch.requests[0].id
)
assert (
generated_texts[0].generated_tokens
generations[0].generated_text.generated_tokens
== default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
)
@ -244,19 +248,20 @@ def test_batch_concatenate(
for _ in range(
default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 2
):
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
assert generated_texts == []
generations, next_batch = default_causal_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
generations, next_batch = default_causal_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generated_texts) == 1
assert generated_texts[0].output_text == "Test.java:784)"
assert len(generations) == 3
assert generations[2].generated_text.text == "Test.java:784)"
assert (
generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1]
generations[2].request_id
== default_multi_requests_causal_lm_batch.requests[1].id
)
assert (
generated_texts[0].generated_tokens
generations[2].generated_text.generated_tokens
== default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
)
@ -265,17 +270,17 @@ def test_batch_concatenate(
- default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
- 2
):
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
assert generated_texts == []
generations, next_batch = default_causal_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
generations, next_batch = default_causal_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generated_texts) == 1
assert generated_texts[0].output_text == "Test.java:784) at net.minecraft."
assert generated_texts[0].request == default_causal_lm_batch.requests[0]
assert len(generations) == 2
assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
assert generations[0].request_id == default_causal_lm_batch.requests[0].id
assert (
generated_texts[0].generated_tokens
generations[0].generated_text.generated_tokens
== default_causal_lm_batch.stopping_criterias[0].max_new_tokens
)
@ -285,18 +290,19 @@ def test_batch_concatenate(
- default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
- 4
):
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
assert generated_texts == []
generations, next_batch = default_causal_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
generations, next_batch = default_causal_lm.generate_token(next_batch)
assert next_batch is None
assert len(generated_texts) == 1
assert generated_texts[0].output_text == "Test.java:784) at net.minecraft."
assert len(generations) == 1
assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
assert (
generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0]
generations[0].request_id
== default_multi_requests_causal_lm_batch.requests[0].id
)
assert (
generated_texts[0].generated_tokens
generations[0].generated_text.generated_tokens
== default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
)

View File

@ -50,18 +50,17 @@ def test_santacoder_generate_token_completion(default_santacoder, default_pb_bat
next_batch = batch
for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
generated_texts, next_batch = default_santacoder.generate_token(next_batch)
assert generated_texts == []
generations, next_batch = default_santacoder.generate_token(next_batch)
assert len(generations) == len(next_batch)
generated_texts, next_batch = default_santacoder.generate_token(next_batch)
generations, next_batch = default_santacoder.generate_token(next_batch)
assert next_batch is None
assert len(generated_texts) == 1
assert generated_texts[0].output_text == "def test_get_all_users_with_"
assert generated_texts[0].request == batch.requests[0]
assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs)
assert len(generations) == 1
assert generations[0].generated_text.text == "def test_get_all_users_with_"
assert generations[0].request_id == batch.requests[0].id
assert (
generated_texts[0].generated_tokens
generations[0].generated_text.generated_tokens
== batch.stopping_criterias[0].max_new_tokens
)
@ -76,20 +75,19 @@ def test_fim_santacoder_generate_token_completion(
next_batch = batch
for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
generated_texts, next_batch = default_santacoder.generate_token(next_batch)
assert generated_texts == []
generations, next_batch = default_santacoder.generate_token(next_batch)
assert len(generations) == len(next_batch)
generated_texts, next_batch = default_santacoder.generate_token(next_batch)
generations, next_batch = default_santacoder.generate_token(next_batch)
assert next_batch is None
assert len(generated_texts) == 1
assert len(generations) == 1
assert (
generated_texts[0].output_text
generations[0].generated_text.text
== """<fim-prefix>def<fim-suffix>world<fim-middle>ineProperty(exports, "__esModule", { value"""
)
assert generated_texts[0].request == batch.requests[0]
assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs)
assert generations[0].request_id == batch.requests[0].id
assert (
generated_texts[0].generated_tokens
generations[0].generated_text.generated_tokens
== batch.stopping_criterias[0].max_new_tokens
)

View File

@ -99,11 +99,11 @@ def test_seq2seq_lm_batch_type(default_seq2seq_lm):
def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch):
sequence_length = len(default_seq2seq_lm_batch.input_ids[0])
generated_texts, next_batch = default_seq2seq_lm.generate_token(
generations, next_batch = default_seq2seq_lm.generate_token(
default_seq2seq_lm_batch
)
assert generated_texts == []
assert len(generations) == len(next_batch)
assert isinstance(next_batch, Seq2SeqLMBatch)
assert torch.equal(next_batch.input_ids, default_seq2seq_lm_batch.input_ids)
@ -145,6 +145,11 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
for p in next_batch.past_key_values
]
)
assert all([generation.generated_text is None for generation in generations])
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
assert all([generation.token_id.item() == 259 for generation in generations])
assert all([generation.token_text == "" for generation in generations])
assert generations[0].request_id == 0
def test_seq2seq_lm_generate_token_completion(
@ -152,16 +157,16 @@ def test_seq2seq_lm_generate_token_completion(
):
next_batch = default_seq2seq_lm_batch
for _ in range(6):
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert generated_texts == []
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is None
assert len(generated_texts) == 1
assert generated_texts[0].output_text == "a few weeks"
assert generated_texts[0].request == default_seq2seq_lm_batch.requests[0]
assert generated_texts[0].generated_tokens == 7
assert len(generations) == 1
assert generations[0].generated_text.text == "a few weeks"
assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
assert generations[0].generated_text.generated_tokens == 7
def test_seq2seq_lm_generate_token_completion_multi(
@ -170,33 +175,33 @@ def test_seq2seq_lm_generate_token_completion_multi(
next_batch = default_multi_requests_seq2seq_lm_batch
for i in range(4):
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert generated_texts == []
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generated_texts) == 1
assert generated_texts[0].output_text == "a few "
assert len(generations) == 2
assert generations[1].generated_text.text == "a few "
assert (
generated_texts[0].request
== default_multi_requests_seq2seq_lm_batch.requests[1]
generations[1].request_id
== default_multi_requests_seq2seq_lm_batch.requests[1].id
)
assert generated_texts[0].generated_tokens == 5
assert generations[1].generated_text.generated_tokens == 5
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert generated_texts == []
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is None
assert len(generated_texts) == 1
assert generated_texts[0].output_text == "a few weeks"
assert len(generations) == 1
assert generations[0].generated_text.text == "a few weeks"
assert (
generated_texts[0].request
== default_multi_requests_seq2seq_lm_batch.requests[0]
generations[0].request_id
== default_multi_requests_seq2seq_lm_batch.requests[0].id
)
assert generated_texts[0].generated_tokens == 7
assert generations[0].generated_text.generated_tokens == 7
def test_batch_concatenate(
@ -291,35 +296,35 @@ def test_batch_concatenate(
)
for _ in range(3):
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert generated_texts == []
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generated_texts) == 1
assert generated_texts[0].output_text == "a few "
assert len(generations) == 3
assert generations[2].generated_text.text == "a few "
assert (
generated_texts[0].request
== default_multi_requests_seq2seq_lm_batch.requests[1]
generations[2].request_id
== default_multi_requests_seq2seq_lm_batch.requests[1].id
)
assert generated_texts[0].generated_tokens == 5
assert generations[2].generated_text.generated_tokens == 5
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generated_texts) == 1
assert generated_texts[0].output_text == "a few weeks"
assert generated_texts[0].request == default_seq2seq_lm_batch.requests[0]
assert generated_texts[0].generated_tokens == 7
assert len(generations) == 2
assert generations[0].generated_text.text == "a few weeks"
assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
assert generations[0].generated_text.generated_tokens == 7
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is None
assert len(generated_texts) == 1
assert generated_texts[0].output_text == "a few weeks"
assert len(generations) == 1
assert generations[0].generated_text.text == "a few weeks"
assert (
generated_texts[0].request
== default_multi_requests_seq2seq_lm_batch.requests[0]
generations[0].request_id
== default_multi_requests_seq2seq_lm_batch.requests[0].id
)
assert generated_texts[0].generated_tokens == 7
assert generations[0].generated_text.generated_tokens == 7

View File

@ -343,9 +343,11 @@ class CausalLM(Model):
# Generated token
next_token_logprob = logprobs[-1, next_token_id]
next_token_id_squeezed = next_token_id.squeeze()
next_token_text = self.tokenizer.decode(next_token_id_squeezed,
next_token_text = self.tokenizer.decode(
next_token_id_squeezed,
clean_up_tokenization_spaces=False,
skip_special_tokens=False)
skip_special_tokens=False,
)
# Evaluate stopping criteria
stop, reason = stopping_criteria(
@ -385,7 +387,9 @@ class CausalLM(Model):
# Prefill
if stopping_criteria.current_tokens == 1:
# Remove generated token to only have prefill and add nan for first prompt token
prefill_logprobs = [float("nan")] + logprobs.gather(1, all_input_ids[1:]).squeeze(1)[-new_input_length:-1].tolist()
prefill_logprobs = [float("nan")] + logprobs.gather(
1, all_input_ids[1:]
).squeeze(1)[-new_input_length:-1].tolist()
prefill_token_ids = all_input_ids[-new_input_length:-1]
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,

View File

@ -158,8 +158,8 @@ class Seq2SeqLMBatch(Batch):
)
# Copy to correct indices
input_ids[
start_index:end_index, -batch.max_input_length:
] = batch.input_ids[:, -batch.max_input_length:]
start_index:end_index, -batch.max_input_length :
] = batch.input_ids[:, -batch.max_input_length :]
# Create padded tensor
if attention_mask is None:
@ -168,8 +168,8 @@ class Seq2SeqLMBatch(Batch):
)
# Copy to correct indices
attention_mask[
start_index:end_index, -batch.max_input_length:
] = batch.attention_mask[:, -batch.max_input_length:]
start_index:end_index, -batch.max_input_length :
] = batch.attention_mask[:, -batch.max_input_length :]
# Create padded tensor
if decoder_input_ids is None:
@ -178,8 +178,8 @@ class Seq2SeqLMBatch(Batch):
)
# Copy to correct indices
decoder_input_ids[
start_index:end_index, -batch.max_decoder_input_length:
] = batch.decoder_input_ids[:, -batch.max_decoder_input_length:]
start_index:end_index, -batch.max_decoder_input_length :
] = batch.decoder_input_ids[:, -batch.max_decoder_input_length :]
# Create padded tensor
if decoder_attention_mask is None:
@ -191,13 +191,13 @@ class Seq2SeqLMBatch(Batch):
# this batch. All generations are of length `batch.max_decoder_input_length`.
if batch.decoder_attention_mask is None:
decoder_attention_mask[
start_index:end_index, -batch.max_decoder_input_length:
start_index:end_index, -batch.max_decoder_input_length :
] = 1
# If it exists, we need to index
else:
decoder_attention_mask[
start_index:end_index, -batch.max_decoder_input_length:
] = batch.decoder_attention_mask[:, -batch.max_decoder_input_length:]
start_index:end_index, -batch.max_decoder_input_length :
] = batch.decoder_attention_mask[:, -batch.max_decoder_input_length :]
# Create padded tensor
if encoder_last_hidden_state is None:
@ -211,8 +211,8 @@ class Seq2SeqLMBatch(Batch):
# Copy to correct indices
encoder_last_hidden_state[
start_index:end_index, -batch.max_input_length:, :
] = batch.encoder_last_hidden_state[:, -batch.max_input_length:, :]
start_index:end_index, -batch.max_input_length :, :
] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :]
# Iterate over attention layers
for j, past in enumerate(batch.past_key_values):
@ -240,9 +240,9 @@ class Seq2SeqLMBatch(Batch):
past_key_values[j][k][
start_index:end_index,
:,
-(batch.max_decoder_input_length - 1):,
-(batch.max_decoder_input_length - 1) :,
:,
] = t[:, :, -(batch.max_decoder_input_length - 1):, :]
] = t[:, :, -(batch.max_decoder_input_length - 1) :, :]
# encoder past
for k, t in enumerate(past[2:]):
@ -261,8 +261,8 @@ class Seq2SeqLMBatch(Batch):
past_key_values[j].append(t.new_zeros(padded_t_shape))
past_key_values[j][idx][
start_index:end_index, :, -batch.max_input_length:, :
] = t[:, :, -batch.max_input_length:, :]
start_index:end_index, :, -batch.max_input_length :, :
] = t[:, :, -batch.max_input_length :, :]
start_index += batch.size
@ -424,15 +424,14 @@ class Seq2SeqLM(Model):
# Generated token
next_token_logprob = logprobs[-1, next_token_id]
next_token_id_squeezed = next_token_id.squeeze()
next_token_text = self.tokenizer.decode(next_token_id_squeezed,
next_token_text = self.tokenizer.decode(
next_token_id_squeezed,
clean_up_tokenization_spaces=False,
skip_special_tokens=False)
skip_special_tokens=False,
)
# Evaluate stopping criteria
stop, reason = stopping_criteria(
next_token_id,
next_token_text
)
stop, reason = stopping_criteria(next_token_id, next_token_text)
if stop:
# Slice with decoder_input_length to remove padding

View File

@ -61,6 +61,9 @@ class PrefillTokens:
ids=self.token_ids, logprobs=self.logprobs, texts=self.texts
)
def __len__(self):
return len(self.token_ids)
@dataclass
class Generation: