working python tests

Commit 42cdb734a5 (parent 4a538cfa49)
Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-09 19:34:53 +00:00
@@ -7,8 +7,8 @@ mod sharded_client;

 pub use client::Client;
 pub use pb::generate::v1::{
-    Batch, GeneratedText, Generation, NextTokenChooserParameters, Request,
-    StoppingCriteriaParameters, PrefillTokens
+    Batch, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens, Request,
+    StoppingCriteriaParameters,
 };
 pub use sharded_client::ShardedClient;
 use thiserror::Error;
@@ -5,7 +5,9 @@ use crate::{Db, Entry, Token};
 use nohash_hasher::IntMap;
 use std::future::Future;
 use std::sync::Arc;
-use text_generation_client::{Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient};
+use text_generation_client::{
+    Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
+};
 use thiserror::Error;
 use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
 use tokio::time::Instant;
@@ -88,7 +90,7 @@ impl Infer {
             input_length,
             time: Instant::now(),
             batch_time: None,
-            _permit: permit
+            _permit: permit,
         });

         // Notify the background task that we have a new entry in the database that needs
@@ -91,9 +91,9 @@ def test_causal_lm_batch_type(default_bloom):

 def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
     sequence_length = len(default_bloom_batch.all_input_ids[0])
-    generated_texts, next_batch = default_bloom.generate_token(default_bloom_batch)
+    generations, next_batch = default_bloom.generate_token(default_bloom_batch)

-    assert generated_texts == []
+    assert len(generations) == len(default_bloom_batch)
     assert isinstance(next_batch, CausalLMBatch)
     assert not next_batch.keys_head_dim_last
@@ -122,24 +122,30 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
     assert all(
         [p[1].shape == (16, sequence_length, 64) for p in next_batch.past_key_values]
     )
+    assert all([generation.generated_text is None for generation in generations])
+    assert all([len(generation.prefill_tokens) == 1 for generation in generations])
+    assert all([generation.token_id.item() == 10264 for generation in generations])
+    assert all([generation.token_text == "Test" for generation in generations])
+    assert generations[0].request_id == 0


 def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch):
     next_batch = default_bloom_batch
     for _ in range(default_bloom_batch.stopping_criterias[0].max_new_tokens - 1):
-        generated_texts, next_batch = default_bloom.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_bloom.generate_token(next_batch)
+        assert len(generations) == len(default_bloom_batch)

-    generated_texts, next_batch = default_bloom.generate_token(next_batch)
+    generations, next_batch = default_bloom.generate_token(next_batch)
     assert next_batch is None

-    assert len(generated_texts) == 1
+    assert len(generations) == 1
     assert (
-        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+        generations[0].generated_text.text
+        == "TestTestTestTestTestTestTestTestTestTestTest"
     )
-    assert generated_texts[0].request == default_bloom_batch.requests[0]
+    assert generations[0].request_id == default_bloom_batch.requests[0].id
     assert (
-        generated_texts[0].generated_tokens
+        generations[0].generated_text.generated_tokens
         == default_bloom_batch.stopping_criterias[0].max_new_tokens
     )
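Editor's note: the hunks above capture the new contract of generate_token on the Python side. Instead of returning a list of finished GeneratedText objects, it now returns one Generation per request still in the batch, and a Generation only carries a generated_text once that request has hit its stopping criteria. A minimal sketch of how a caller could drain a batch under this contract (the helper name and loop are illustrative, not taken from the repository):

def drain(model, batch):
    # Illustrative helper (not from the repository): step generate_token until the
    # batch is exhausted, collecting the finished texts.
    finished = []
    while batch is not None:
        generations, batch = model.generate_token(batch)
        # Each in-flight request yields exactly one Generation per step; its
        # generated_text stays None until the request's stopping criteria fire.
        finished.extend(
            g.generated_text for g in generations if g.generated_text is not None
        )
    return finished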
@@ -152,17 +158,19 @@ def test_causal_lm_generate_token_completion_multi(
     for i in range(
         default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 1
     ):
-        generated_texts, next_batch = default_bloom.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_bloom.generate_token(next_batch)
+        assert len(generations) == len(default_multi_requests_bloom_batch)

-    generated_texts, next_batch = default_bloom.generate_token(next_batch)
+    generations, next_batch = default_bloom.generate_token(next_batch)
     assert next_batch is not None

-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "TestTestTestTestTestTest"
-    assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1]
+    assert len(generations) == 2
+    assert generations[1].generated_text.text == "TestTestTestTestTestTest"
     assert (
-        generated_texts[0].generated_tokens
+        generations[1].request_id == default_multi_requests_bloom_batch.requests[1].id
+    )
+    assert (
+        generations[1].generated_text.generated_tokens
         == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
     )
@@ -171,19 +179,22 @@ def test_causal_lm_generate_token_completion_multi(
         - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
         - 1
     ):
-        generated_texts, next_batch = default_bloom.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_bloom.generate_token(next_batch)
+        assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_bloom.generate_token(next_batch)
+    generations, next_batch = default_bloom.generate_token(next_batch)
     assert next_batch is None

-    assert len(generated_texts) == 1
+    assert len(generations) == 1
     assert (
-        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+        generations[0].generated_text.text
+        == "TestTestTestTestTestTestTestTestTestTestTest"
     )
-    assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
     assert (
-        generated_texts[0].generated_tokens
+        generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
+    )
+    assert (
+        generations[0].generated_text.generated_tokens
         == default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
     )
@@ -243,17 +254,19 @@ def test_batch_concatenate(
     for _ in range(
         default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 2
     ):
-        generated_texts, next_batch = default_bloom.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_bloom.generate_token(next_batch)
+        assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_bloom.generate_token(next_batch)
+    generations, next_batch = default_bloom.generate_token(next_batch)
     assert next_batch is not None

-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "TestTestTestTestTestTest"
-    assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1]
+    assert len(generations) == 3
+    assert generations[2].generated_text.text == "TestTestTestTestTestTest"
     assert (
-        generated_texts[0].generated_tokens
+        generations[2].request_id == default_multi_requests_bloom_batch.requests[1].id
+    )
+    assert (
+        generations[2].generated_text.generated_tokens
         == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
     )
@@ -262,19 +275,20 @@ def test_batch_concatenate(
         - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
         - 2
     ):
-        generated_texts, next_batch = default_bloom.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_bloom.generate_token(next_batch)
+        assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_bloom.generate_token(next_batch)
+    generations, next_batch = default_bloom.generate_token(next_batch)
     assert next_batch is not None

-    assert len(generated_texts) == 1
+    assert len(generations) == 2
     assert (
-        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+        generations[0].generated_text.text
+        == "TestTestTestTestTestTestTestTestTestTestTest"
     )
-    assert generated_texts[0].request == default_bloom_batch.requests[0]
+    assert generations[0].request_id == default_bloom_batch.requests[0].id
     assert (
-        generated_texts[0].generated_tokens
+        generations[0].generated_text.generated_tokens
         == default_bloom_batch.stopping_criterias[0].max_new_tokens
     )
@@ -284,18 +298,21 @@ def test_batch_concatenate(
         - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
         - 4
     ):
-        generated_texts, next_batch = default_bloom.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_bloom.generate_token(next_batch)
+        assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_bloom.generate_token(next_batch)
+    generations, next_batch = default_bloom.generate_token(next_batch)
     assert next_batch is None

-    assert len(generated_texts) == 1
+    assert len(generations) == 1
     assert (
-        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+        generations[0].generated_text.text
+        == "TestTestTestTestTestTestTestTestTestTestTest"
     )
-    assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
     assert (
-        generated_texts[0].generated_tokens
+        generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
+    )
+    assert (
+        generations[0].generated_text.generated_tokens
         == default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
     )
@@ -88,11 +88,9 @@ def test_causal_lm_batch_type(default_causal_lm):

 def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
     sequence_length = len(default_causal_lm_batch.all_input_ids[0])
-    generated_texts, next_batch = default_causal_lm.generate_token(
-        default_causal_lm_batch
-    )
+    generations, next_batch = default_causal_lm.generate_token(default_causal_lm_batch)

-    assert generated_texts == []
+    assert len(generations) == len(next_batch)
     assert isinstance(next_batch, CausalLMBatch)

     assert len(next_batch.all_input_ids) == next_batch.size
@@ -121,6 +119,11 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
     assert all(
         [p[1].shape == (1, 12, sequence_length, 64) for p in next_batch.past_key_values]
     )
+    assert all([generation.generated_text is None for generation in generations])
+    assert all([len(generation.prefill_tokens) == 1 for generation in generations])
+    assert all([generation.token_id.item() == 13 for generation in generations])
+    assert all([generation.token_text == "." for generation in generations])
+    assert generations[0].request_id == 0


 def test_causal_lm_generate_token_completion(
@@ -128,18 +131,17 @@ def test_causal_lm_generate_token_completion(
 ):
     next_batch = default_causal_lm_batch
     for _ in range(default_causal_lm_batch.stopping_criterias[0].max_new_tokens - 1):
-        generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_causal_lm.generate_token(next_batch)
+        assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
+    generations, next_batch = default_causal_lm.generate_token(next_batch)
     assert next_batch is None

-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "Test.java:784) at net.minecraft."
-    assert generated_texts[0].request == default_causal_lm_batch.requests[0]
-    assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs)
+    assert len(generations) == 1
+    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].request_id == default_causal_lm_batch.requests[0].id
     assert (
-        generated_texts[0].generated_tokens
+        generations[0].generated_text.generated_tokens
         == default_causal_lm_batch.stopping_criterias[0].max_new_tokens
     )
@@ -152,19 +154,20 @@ def test_causal_lm_generate_token_completion_multi(
     for i in range(
         default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 1
     ):
-        generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_causal_lm.generate_token(next_batch)
+        assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
+    generations, next_batch = default_causal_lm.generate_token(next_batch)
     assert next_batch is not None

-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "Test.java:784)"
+    assert len(generations) == 2
+    assert generations[1].generated_text.text == "Test.java:784)"
     assert (
-        generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1]
+        generations[1].request_id
+        == default_multi_requests_causal_lm_batch.requests[1].id
     )
     assert (
-        generated_texts[0].generated_tokens
+        generations[1].generated_text.generated_tokens
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )
@@ -173,19 +176,20 @@ def test_causal_lm_generate_token_completion_multi(
         - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
         - 1
     ):
-        generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_causal_lm.generate_token(next_batch)
+        assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
+    generations, next_batch = default_causal_lm.generate_token(next_batch)
     assert next_batch is None

-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "Test.java:784) at net.minecraft."
+    assert len(generations) == 1
+    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
     assert (
-        generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0]
+        generations[0].request_id
+        == default_multi_requests_causal_lm_batch.requests[0].id
     )
     assert (
-        generated_texts[0].generated_tokens
+        generations[0].generated_text.generated_tokens
         == default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
     )
@@ -244,19 +248,20 @@ def test_batch_concatenate(
     for _ in range(
         default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 2
     ):
-        generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_causal_lm.generate_token(next_batch)
+        assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
+    generations, next_batch = default_causal_lm.generate_token(next_batch)
     assert next_batch is not None

-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "Test.java:784)"
+    assert len(generations) == 3
+    assert generations[2].generated_text.text == "Test.java:784)"
     assert (
-        generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1]
+        generations[2].request_id
+        == default_multi_requests_causal_lm_batch.requests[1].id
     )
     assert (
-        generated_texts[0].generated_tokens
+        generations[2].generated_text.generated_tokens
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )
@@ -265,17 +270,17 @@ def test_batch_concatenate(
         - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
         - 2
     ):
-        generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_causal_lm.generate_token(next_batch)
+        assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
+    generations, next_batch = default_causal_lm.generate_token(next_batch)
     assert next_batch is not None

-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "Test.java:784) at net.minecraft."
-    assert generated_texts[0].request == default_causal_lm_batch.requests[0]
+    assert len(generations) == 2
+    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].request_id == default_causal_lm_batch.requests[0].id
     assert (
-        generated_texts[0].generated_tokens
+        generations[0].generated_text.generated_tokens
         == default_causal_lm_batch.stopping_criterias[0].max_new_tokens
     )
@@ -285,18 +290,19 @@ def test_batch_concatenate(
         - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
         - 4
     ):
-        generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_causal_lm.generate_token(next_batch)
+        assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
+    generations, next_batch = default_causal_lm.generate_token(next_batch)
     assert next_batch is None

-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "Test.java:784) at net.minecraft."
+    assert len(generations) == 1
+    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
     assert (
-        generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0]
+        generations[0].request_id
+        == default_multi_requests_causal_lm_batch.requests[0].id
     )
     assert (
-        generated_texts[0].generated_tokens
+        generations[0].generated_text.generated_tokens
         == default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
     )
@@ -50,18 +50,17 @@ def test_santacoder_generate_token_completion(default_santacoder, default_pb_bat
     next_batch = batch

     for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
-        generated_texts, next_batch = default_santacoder.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_santacoder.generate_token(next_batch)
+        assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_santacoder.generate_token(next_batch)
+    generations, next_batch = default_santacoder.generate_token(next_batch)
     assert next_batch is None

-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "def test_get_all_users_with_"
-    assert generated_texts[0].request == batch.requests[0]
-    assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs)
+    assert len(generations) == 1
+    assert generations[0].generated_text.text == "def test_get_all_users_with_"
+    assert generations[0].request_id == batch.requests[0].id
     assert (
-        generated_texts[0].generated_tokens
+        generations[0].generated_text.generated_tokens
         == batch.stopping_criterias[0].max_new_tokens
     )
@@ -76,20 +75,19 @@ def test_fim_santacoder_generate_token_completion(
     next_batch = batch

     for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
-        generated_texts, next_batch = default_santacoder.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_santacoder.generate_token(next_batch)
+        assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_santacoder.generate_token(next_batch)
+    generations, next_batch = default_santacoder.generate_token(next_batch)
     assert next_batch is None

-    assert len(generated_texts) == 1
+    assert len(generations) == 1
     assert (
-        generated_texts[0].output_text
+        generations[0].generated_text.text
         == """<fim-prefix>def<fim-suffix>world<fim-middle>ineProperty(exports, "__esModule", { value"""
     )
-    assert generated_texts[0].request == batch.requests[0]
-    assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs)
+    assert generations[0].request_id == batch.requests[0].id
     assert (
-        generated_texts[0].generated_tokens
+        generations[0].generated_text.generated_tokens
         == batch.stopping_criterias[0].max_new_tokens
     )
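Editor's note: the fill-in-the-middle test above decodes a prompt built from SantaCoder-style FIM sentinel strings. A small illustrative sketch of that prompt layout (the helper function is made up; the sentinel strings come from the expected output asserted in the test):

def build_fim_prompt(prefix: str, suffix: str) -> str:
    # Hypothetical helper: assemble a fill-in-the-middle prompt from the sentinel
    # strings that appear in the expected output above.
    return f"<fim-prefix>{prefix}<fim-suffix>{suffix}<fim-middle>"

# The model is then expected to generate the "middle" continuation after <fim-middle>.
prompt = build_fim_prompt("def", "world")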
@@ -99,11 +99,11 @@ def test_seq2seq_lm_batch_type(default_seq2seq_lm):

 def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch):
     sequence_length = len(default_seq2seq_lm_batch.input_ids[0])
-    generated_texts, next_batch = default_seq2seq_lm.generate_token(
+    generations, next_batch = default_seq2seq_lm.generate_token(
         default_seq2seq_lm_batch
     )

-    assert generated_texts == []
+    assert len(generations) == len(next_batch)
     assert isinstance(next_batch, Seq2SeqLMBatch)

     assert torch.equal(next_batch.input_ids, default_seq2seq_lm_batch.input_ids)
@@ -145,6 +145,11 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
             for p in next_batch.past_key_values
         ]
     )
+    assert all([generation.generated_text is None for generation in generations])
+    assert all([len(generation.prefill_tokens) == 1 for generation in generations])
+    assert all([generation.token_id.item() == 259 for generation in generations])
+    assert all([generation.token_text == "" for generation in generations])
+    assert generations[0].request_id == 0


 def test_seq2seq_lm_generate_token_completion(
@@ -152,16 +157,16 @@ def test_seq2seq_lm_generate_token_completion(
 ):
     next_batch = default_seq2seq_lm_batch
     for _ in range(6):
-        generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+        assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is None

-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "a few weeks"
-    assert generated_texts[0].request == default_seq2seq_lm_batch.requests[0]
-    assert generated_texts[0].generated_tokens == 7
+    assert len(generations) == 1
+    assert generations[0].generated_text.text == "a few weeks"
+    assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
+    assert generations[0].generated_text.generated_tokens == 7


 def test_seq2seq_lm_generate_token_completion_multi(
@@ -170,33 +175,33 @@ def test_seq2seq_lm_generate_token_completion_multi(
     next_batch = default_multi_requests_seq2seq_lm_batch

     for i in range(4):
-        generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+        assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None

-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "a few "
+    assert len(generations) == 2
+    assert generations[1].generated_text.text == "a few "
     assert (
-        generated_texts[0].request
-        == default_multi_requests_seq2seq_lm_batch.requests[1]
+        generations[1].request_id
+        == default_multi_requests_seq2seq_lm_batch.requests[1].id
     )
-    assert generated_texts[0].generated_tokens == 5
+    assert generations[1].generated_text.generated_tokens == 5

-    generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
-    assert generated_texts == []
+    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is None

-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "a few weeks"
+    assert len(generations) == 1
+    assert generations[0].generated_text.text == "a few weeks"
     assert (
-        generated_texts[0].request
-        == default_multi_requests_seq2seq_lm_batch.requests[0]
+        generations[0].request_id
+        == default_multi_requests_seq2seq_lm_batch.requests[0].id
     )
-    assert generated_texts[0].generated_tokens == 7
+    assert generations[0].generated_text.generated_tokens == 7


 def test_batch_concatenate(
@@ -291,35 +296,35 @@ def test_batch_concatenate(
     )

     for _ in range(3):
-        generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+        assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None

-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "a few "
+    assert len(generations) == 3
+    assert generations[2].generated_text.text == "a few "
     assert (
-        generated_texts[0].request
-        == default_multi_requests_seq2seq_lm_batch.requests[1]
+        generations[2].request_id
+        == default_multi_requests_seq2seq_lm_batch.requests[1].id
     )
-    assert generated_texts[0].generated_tokens == 5
+    assert generations[2].generated_text.generated_tokens == 5

-    generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None

-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "a few weeks"
-    assert generated_texts[0].request == default_seq2seq_lm_batch.requests[0]
-    assert generated_texts[0].generated_tokens == 7
+    assert len(generations) == 2
+    assert generations[0].generated_text.text == "a few weeks"
+    assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
+    assert generations[0].generated_text.generated_tokens == 7

-    generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is None

-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "a few weeks"
+    assert len(generations) == 1
+    assert generations[0].generated_text.text == "a few weeks"
     assert (
-        generated_texts[0].request
-        == default_multi_requests_seq2seq_lm_batch.requests[0]
+        generations[0].request_id
+        == default_multi_requests_seq2seq_lm_batch.requests[0].id
     )
-    assert generated_texts[0].generated_tokens == 7
+    assert generations[0].generated_text.generated_tokens == 7
@@ -343,9 +343,11 @@ class CausalLM(Model):
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
-            next_token_text = self.tokenizer.decode(next_token_id_squeezed,
+            next_token_text = self.tokenizer.decode(
+                next_token_id_squeezed,
                 clean_up_tokenization_spaces=False,
-                skip_special_tokens=False)
+                skip_special_tokens=False,
+            )

             # Evaluate stopping criteria
             stop, reason = stopping_criteria(
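Editor's note: the change above only re-wraps the decode call in black style; behaviour is unchanged. For reference, a hedged sketch of the same call pattern with a Hugging Face tokenizer (the model name is chosen only for illustration):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")  # illustrative model
token_id = tokenizer("Test")["input_ids"][0]

# Same keyword arguments as in the diff: keep special tokens and raw spacing so the
# decoded piece can be concatenated back into the generated text verbatim.
next_token_text = tokenizer.decode(
    token_id,
    clean_up_tokenization_spaces=False,
    skip_special_tokens=False,
)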
@@ -385,7 +387,9 @@ class CausalLM(Model):
             # Prefill
             if stopping_criteria.current_tokens == 1:
                 # Remove generated token to only have prefill and add nan for first prompt token
-                prefill_logprobs = [float("nan")] + logprobs.gather(1, all_input_ids[1:]).squeeze(1)[-new_input_length:-1].tolist()
+                prefill_logprobs = [float("nan")] + logprobs.gather(
+                    1, all_input_ids[1:]
+                ).squeeze(1)[-new_input_length:-1].tolist()
                 prefill_token_ids = all_input_ids[-new_input_length:-1]
                 prefill_texts = self.tokenizer.batch_decode(
                     prefill_token_ids,
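Editor's note: the reformatted expression above collects, during prefill, the log-probability the model assigned to each prompt token after the first (the first token has no prediction, hence the leading NaN). A toy illustration of that gather pattern with made-up shapes:

import torch

# Toy shapes (assumptions for illustration): logprobs is [seq_len, vocab_size] and
# row i holds the distribution predicted for position i + 1; token_ids is [seq_len, 1].
logprobs = torch.log_softmax(torch.randn(4, 8), dim=-1)
token_ids = torch.randint(0, 8, (4, 1))

# Pick, for each position, the log-probability of the token that actually follows it,
# and prepend NaN for the first prompt token, which has no prediction.
prefill_logprobs = [float("nan")] + logprobs.gather(1, token_ids[1:]).squeeze(1).tolist()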
@@ -424,15 +424,14 @@ class Seq2SeqLM(Model):
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id]
             next_token_id_squeezed = next_token_id.squeeze()
-            next_token_text = self.tokenizer.decode(next_token_id_squeezed,
+            next_token_text = self.tokenizer.decode(
+                next_token_id_squeezed,
                 clean_up_tokenization_spaces=False,
-                skip_special_tokens=False)
+                skip_special_tokens=False,
+            )

             # Evaluate stopping criteria
-            stop, reason = stopping_criteria(
-                next_token_id,
-                next_token_text
-            )
+            stop, reason = stopping_criteria(next_token_id, next_token_text)

             if stop:
                 # Slice with decoder_input_length to remove padding
@@ -61,6 +61,9 @@ class PrefillTokens:
             ids=self.token_ids, logprobs=self.logprobs, texts=self.texts
         )

+    def __len__(self):
+        return len(self.token_ids)
+

 @dataclass
 class Generation:
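Editor's note: for orientation, here is a hedged sketch of the shapes the tests above rely on. PrefillTokens and its new __len__ come from the hunk directly above; the Generation and GeneratedText fields are inferred from the test assertions (request_id, prefill_tokens, token_id, token_text, generated_text.text, generated_text.generated_tokens) and may not match the real class definitions exactly:

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class GeneratedText:
    # Fields inferred from the assertions; the real class likely carries more
    # (for example a finish reason).
    text: str
    generated_tokens: int

@dataclass
class PrefillTokens:
    token_ids: List[int]
    logprobs: List[float]
    texts: List[str]

    def __len__(self):
        return len(self.token_ids)

@dataclass
class Generation:
    request_id: int
    prefill_tokens: Optional[PrefillTokens]
    token_id: int          # a tensor in the tests (token_id.item() is called)
    token_text: str
    generated_text: Optional[GeneratedText]  # None until the request finishes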