mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
working python tests
This commit is contained in:
parent
4a538cfa49
commit
42cdb734a5
@ -7,8 +7,8 @@ mod sharded_client;
|
|||||||
|
|
||||||
pub use client::Client;
|
pub use client::Client;
|
||||||
pub use pb::generate::v1::{
|
pub use pb::generate::v1::{
|
||||||
Batch, GeneratedText, Generation, NextTokenChooserParameters, Request,
|
Batch, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens, Request,
|
||||||
StoppingCriteriaParameters, PrefillTokens
|
StoppingCriteriaParameters,
|
||||||
};
|
};
|
||||||
pub use sharded_client::ShardedClient;
|
pub use sharded_client::ShardedClient;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
@ -5,7 +5,9 @@ use crate::{Db, Entry, Token};
|
|||||||
use nohash_hasher::IntMap;
|
use nohash_hasher::IntMap;
|
||||||
use std::future::Future;
|
use std::future::Future;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use text_generation_client::{Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient};
|
use text_generation_client::{
|
||||||
|
Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
|
||||||
|
};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
|
use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
@ -88,7 +90,7 @@ impl Infer {
|
|||||||
input_length,
|
input_length,
|
||||||
time: Instant::now(),
|
time: Instant::now(),
|
||||||
batch_time: None,
|
batch_time: None,
|
||||||
_permit: permit
|
_permit: permit,
|
||||||
});
|
});
|
||||||
|
|
||||||
// Notify the background task that we have a new entry in the database that needs
|
// Notify the background task that we have a new entry in the database that needs
|
||||||
@ -233,7 +235,7 @@ async fn batching_task(
|
|||||||
|
|
||||||
/// Wrap a future inside a match statement to handle errors and send the responses to Infer
|
/// Wrap a future inside a match statement to handle errors and send the responses to Infer
|
||||||
async fn wrap_future(
|
async fn wrap_future(
|
||||||
future: impl Future<Output=Result<(Vec<Generation>, Option<Batch>), ClientError>>,
|
future: impl Future<Output = Result<(Vec<Generation>, Option<Batch>), ClientError>>,
|
||||||
entries: &mut IntMap<u64, Entry>,
|
entries: &mut IntMap<u64, Entry>,
|
||||||
) -> Option<Batch> {
|
) -> Option<Batch> {
|
||||||
match future.await {
|
match future.await {
|
||||||
|
@ -91,9 +91,9 @@ def test_causal_lm_batch_type(default_bloom):
|
|||||||
|
|
||||||
def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
|
def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
|
||||||
sequence_length = len(default_bloom_batch.all_input_ids[0])
|
sequence_length = len(default_bloom_batch.all_input_ids[0])
|
||||||
generated_texts, next_batch = default_bloom.generate_token(default_bloom_batch)
|
generations, next_batch = default_bloom.generate_token(default_bloom_batch)
|
||||||
|
|
||||||
assert generated_texts == []
|
assert len(generations) == len(default_bloom_batch)
|
||||||
assert isinstance(next_batch, CausalLMBatch)
|
assert isinstance(next_batch, CausalLMBatch)
|
||||||
assert not next_batch.keys_head_dim_last
|
assert not next_batch.keys_head_dim_last
|
||||||
|
|
||||||
@ -122,24 +122,30 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
|
|||||||
assert all(
|
assert all(
|
||||||
[p[1].shape == (16, sequence_length, 64) for p in next_batch.past_key_values]
|
[p[1].shape == (16, sequence_length, 64) for p in next_batch.past_key_values]
|
||||||
)
|
)
|
||||||
|
assert all([generation.generated_text is None for generation in generations])
|
||||||
|
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
|
||||||
|
assert all([generation.token_id.item() == 10264 for generation in generations])
|
||||||
|
assert all([generation.token_text == "Test" for generation in generations])
|
||||||
|
assert generations[0].request_id == 0
|
||||||
|
|
||||||
|
|
||||||
def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch):
|
def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch):
|
||||||
next_batch = default_bloom_batch
|
next_batch = default_bloom_batch
|
||||||
for _ in range(default_bloom_batch.stopping_criterias[0].max_new_tokens - 1):
|
for _ in range(default_bloom_batch.stopping_criterias[0].max_new_tokens - 1):
|
||||||
generated_texts, next_batch = default_bloom.generate_token(next_batch)
|
generations, next_batch = default_bloom.generate_token(next_batch)
|
||||||
assert generated_texts == []
|
assert len(generations) == len(default_bloom_batch)
|
||||||
|
|
||||||
generated_texts, next_batch = default_bloom.generate_token(next_batch)
|
generations, next_batch = default_bloom.generate_token(next_batch)
|
||||||
assert next_batch is None
|
assert next_batch is None
|
||||||
|
|
||||||
assert len(generated_texts) == 1
|
assert len(generations) == 1
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
|
generations[0].generated_text.text
|
||||||
|
== "TestTestTestTestTestTestTestTestTestTestTest"
|
||||||
)
|
)
|
||||||
assert generated_texts[0].request == default_bloom_batch.requests[0]
|
assert generations[0].request_id == default_bloom_batch.requests[0].id
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].generated_tokens
|
generations[0].generated_text.generated_tokens
|
||||||
== default_bloom_batch.stopping_criterias[0].max_new_tokens
|
== default_bloom_batch.stopping_criterias[0].max_new_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -152,17 +158,19 @@ def test_causal_lm_generate_token_completion_multi(
|
|||||||
for i in range(
|
for i in range(
|
||||||
default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 1
|
default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 1
|
||||||
):
|
):
|
||||||
generated_texts, next_batch = default_bloom.generate_token(next_batch)
|
generations, next_batch = default_bloom.generate_token(next_batch)
|
||||||
assert generated_texts == []
|
assert len(generations) == len(default_multi_requests_bloom_batch)
|
||||||
|
|
||||||
generated_texts, next_batch = default_bloom.generate_token(next_batch)
|
generations, next_batch = default_bloom.generate_token(next_batch)
|
||||||
assert next_batch is not None
|
assert next_batch is not None
|
||||||
|
|
||||||
assert len(generated_texts) == 1
|
assert len(generations) == 2
|
||||||
assert generated_texts[0].output_text == "TestTestTestTestTestTest"
|
assert generations[1].generated_text.text == "TestTestTestTestTestTest"
|
||||||
assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1]
|
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].generated_tokens
|
generations[1].request_id == default_multi_requests_bloom_batch.requests[1].id
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
generations[1].generated_text.generated_tokens
|
||||||
== default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
|
== default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -171,19 +179,22 @@ def test_causal_lm_generate_token_completion_multi(
|
|||||||
- default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
|
- default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
|
||||||
- 1
|
- 1
|
||||||
):
|
):
|
||||||
generated_texts, next_batch = default_bloom.generate_token(next_batch)
|
generations, next_batch = default_bloom.generate_token(next_batch)
|
||||||
assert generated_texts == []
|
assert len(generations) == len(next_batch)
|
||||||
|
|
||||||
generated_texts, next_batch = default_bloom.generate_token(next_batch)
|
generations, next_batch = default_bloom.generate_token(next_batch)
|
||||||
assert next_batch is None
|
assert next_batch is None
|
||||||
|
|
||||||
assert len(generated_texts) == 1
|
assert len(generations) == 1
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
|
generations[0].generated_text.text
|
||||||
|
== "TestTestTestTestTestTestTestTestTestTestTest"
|
||||||
)
|
)
|
||||||
assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
|
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].generated_tokens
|
generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
generations[0].generated_text.generated_tokens
|
||||||
== default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
|
== default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -243,17 +254,19 @@ def test_batch_concatenate(
|
|||||||
for _ in range(
|
for _ in range(
|
||||||
default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 2
|
default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 2
|
||||||
):
|
):
|
||||||
generated_texts, next_batch = default_bloom.generate_token(next_batch)
|
generations, next_batch = default_bloom.generate_token(next_batch)
|
||||||
assert generated_texts == []
|
assert len(generations) == len(next_batch)
|
||||||
|
|
||||||
generated_texts, next_batch = default_bloom.generate_token(next_batch)
|
generations, next_batch = default_bloom.generate_token(next_batch)
|
||||||
assert next_batch is not None
|
assert next_batch is not None
|
||||||
|
|
||||||
assert len(generated_texts) == 1
|
assert len(generations) == 3
|
||||||
assert generated_texts[0].output_text == "TestTestTestTestTestTest"
|
assert generations[2].generated_text.text == "TestTestTestTestTestTest"
|
||||||
assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1]
|
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].generated_tokens
|
generations[2].request_id == default_multi_requests_bloom_batch.requests[1].id
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
generations[2].generated_text.generated_tokens
|
||||||
== default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
|
== default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -262,19 +275,20 @@ def test_batch_concatenate(
|
|||||||
- default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
|
- default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
|
||||||
- 2
|
- 2
|
||||||
):
|
):
|
||||||
generated_texts, next_batch = default_bloom.generate_token(next_batch)
|
generations, next_batch = default_bloom.generate_token(next_batch)
|
||||||
assert generated_texts == []
|
assert len(generations) == len(next_batch)
|
||||||
|
|
||||||
generated_texts, next_batch = default_bloom.generate_token(next_batch)
|
generations, next_batch = default_bloom.generate_token(next_batch)
|
||||||
assert next_batch is not None
|
assert next_batch is not None
|
||||||
|
|
||||||
assert len(generated_texts) == 1
|
assert len(generations) == 2
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
|
generations[0].generated_text.text
|
||||||
|
== "TestTestTestTestTestTestTestTestTestTestTest"
|
||||||
)
|
)
|
||||||
assert generated_texts[0].request == default_bloom_batch.requests[0]
|
assert generations[0].request_id == default_bloom_batch.requests[0].id
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].generated_tokens
|
generations[0].generated_text.generated_tokens
|
||||||
== default_bloom_batch.stopping_criterias[0].max_new_tokens
|
== default_bloom_batch.stopping_criterias[0].max_new_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -284,18 +298,21 @@ def test_batch_concatenate(
|
|||||||
- default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
|
- default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
|
||||||
- 4
|
- 4
|
||||||
):
|
):
|
||||||
generated_texts, next_batch = default_bloom.generate_token(next_batch)
|
generations, next_batch = default_bloom.generate_token(next_batch)
|
||||||
assert generated_texts == []
|
assert len(generations) == len(next_batch)
|
||||||
|
|
||||||
generated_texts, next_batch = default_bloom.generate_token(next_batch)
|
generations, next_batch = default_bloom.generate_token(next_batch)
|
||||||
assert next_batch is None
|
assert next_batch is None
|
||||||
|
|
||||||
assert len(generated_texts) == 1
|
assert len(generations) == 1
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
|
generations[0].generated_text.text
|
||||||
|
== "TestTestTestTestTestTestTestTestTestTestTest"
|
||||||
)
|
)
|
||||||
assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
|
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].generated_tokens
|
generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
generations[0].generated_text.generated_tokens
|
||||||
== default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
|
== default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
|
||||||
)
|
)
|
||||||
|
@ -88,11 +88,9 @@ def test_causal_lm_batch_type(default_causal_lm):
|
|||||||
|
|
||||||
def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
|
def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
|
||||||
sequence_length = len(default_causal_lm_batch.all_input_ids[0])
|
sequence_length = len(default_causal_lm_batch.all_input_ids[0])
|
||||||
generated_texts, next_batch = default_causal_lm.generate_token(
|
generations, next_batch = default_causal_lm.generate_token(default_causal_lm_batch)
|
||||||
default_causal_lm_batch
|
|
||||||
)
|
|
||||||
|
|
||||||
assert generated_texts == []
|
assert len(generations) == len(next_batch)
|
||||||
assert isinstance(next_batch, CausalLMBatch)
|
assert isinstance(next_batch, CausalLMBatch)
|
||||||
|
|
||||||
assert len(next_batch.all_input_ids) == next_batch.size
|
assert len(next_batch.all_input_ids) == next_batch.size
|
||||||
@ -121,6 +119,11 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
|
|||||||
assert all(
|
assert all(
|
||||||
[p[1].shape == (1, 12, sequence_length, 64) for p in next_batch.past_key_values]
|
[p[1].shape == (1, 12, sequence_length, 64) for p in next_batch.past_key_values]
|
||||||
)
|
)
|
||||||
|
assert all([generation.generated_text is None for generation in generations])
|
||||||
|
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
|
||||||
|
assert all([generation.token_id.item() == 13 for generation in generations])
|
||||||
|
assert all([generation.token_text == "." for generation in generations])
|
||||||
|
assert generations[0].request_id == 0
|
||||||
|
|
||||||
|
|
||||||
def test_causal_lm_generate_token_completion(
|
def test_causal_lm_generate_token_completion(
|
||||||
@ -128,18 +131,17 @@ def test_causal_lm_generate_token_completion(
|
|||||||
):
|
):
|
||||||
next_batch = default_causal_lm_batch
|
next_batch = default_causal_lm_batch
|
||||||
for _ in range(default_causal_lm_batch.stopping_criterias[0].max_new_tokens - 1):
|
for _ in range(default_causal_lm_batch.stopping_criterias[0].max_new_tokens - 1):
|
||||||
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
|
generations, next_batch = default_causal_lm.generate_token(next_batch)
|
||||||
assert generated_texts == []
|
assert len(generations) == len(next_batch)
|
||||||
|
|
||||||
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
|
generations, next_batch = default_causal_lm.generate_token(next_batch)
|
||||||
assert next_batch is None
|
assert next_batch is None
|
||||||
|
|
||||||
assert len(generated_texts) == 1
|
assert len(generations) == 1
|
||||||
assert generated_texts[0].output_text == "Test.java:784) at net.minecraft."
|
assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
|
||||||
assert generated_texts[0].request == default_causal_lm_batch.requests[0]
|
assert generations[0].request_id == default_causal_lm_batch.requests[0].id
|
||||||
assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs)
|
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].generated_tokens
|
generations[0].generated_text.generated_tokens
|
||||||
== default_causal_lm_batch.stopping_criterias[0].max_new_tokens
|
== default_causal_lm_batch.stopping_criterias[0].max_new_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -152,19 +154,20 @@ def test_causal_lm_generate_token_completion_multi(
|
|||||||
for i in range(
|
for i in range(
|
||||||
default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 1
|
default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 1
|
||||||
):
|
):
|
||||||
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
|
generations, next_batch = default_causal_lm.generate_token(next_batch)
|
||||||
assert generated_texts == []
|
assert len(generations) == len(next_batch)
|
||||||
|
|
||||||
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
|
generations, next_batch = default_causal_lm.generate_token(next_batch)
|
||||||
assert next_batch is not None
|
assert next_batch is not None
|
||||||
|
|
||||||
assert len(generated_texts) == 1
|
assert len(generations) == 2
|
||||||
assert generated_texts[0].output_text == "Test.java:784)"
|
assert generations[1].generated_text.text == "Test.java:784)"
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1]
|
generations[1].request_id
|
||||||
|
== default_multi_requests_causal_lm_batch.requests[1].id
|
||||||
)
|
)
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].generated_tokens
|
generations[1].generated_text.generated_tokens
|
||||||
== default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
|
== default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -173,19 +176,20 @@ def test_causal_lm_generate_token_completion_multi(
|
|||||||
- default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
|
- default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
|
||||||
- 1
|
- 1
|
||||||
):
|
):
|
||||||
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
|
generations, next_batch = default_causal_lm.generate_token(next_batch)
|
||||||
assert generated_texts == []
|
assert len(generations) == len(next_batch)
|
||||||
|
|
||||||
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
|
generations, next_batch = default_causal_lm.generate_token(next_batch)
|
||||||
assert next_batch is None
|
assert next_batch is None
|
||||||
|
|
||||||
assert len(generated_texts) == 1
|
assert len(generations) == 1
|
||||||
assert generated_texts[0].output_text == "Test.java:784) at net.minecraft."
|
assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0]
|
generations[0].request_id
|
||||||
|
== default_multi_requests_causal_lm_batch.requests[0].id
|
||||||
)
|
)
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].generated_tokens
|
generations[0].generated_text.generated_tokens
|
||||||
== default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
|
== default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -244,19 +248,20 @@ def test_batch_concatenate(
|
|||||||
for _ in range(
|
for _ in range(
|
||||||
default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 2
|
default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 2
|
||||||
):
|
):
|
||||||
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
|
generations, next_batch = default_causal_lm.generate_token(next_batch)
|
||||||
assert generated_texts == []
|
assert len(generations) == len(next_batch)
|
||||||
|
|
||||||
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
|
generations, next_batch = default_causal_lm.generate_token(next_batch)
|
||||||
assert next_batch is not None
|
assert next_batch is not None
|
||||||
|
|
||||||
assert len(generated_texts) == 1
|
assert len(generations) == 3
|
||||||
assert generated_texts[0].output_text == "Test.java:784)"
|
assert generations[2].generated_text.text == "Test.java:784)"
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1]
|
generations[2].request_id
|
||||||
|
== default_multi_requests_causal_lm_batch.requests[1].id
|
||||||
)
|
)
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].generated_tokens
|
generations[2].generated_text.generated_tokens
|
||||||
== default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
|
== default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -265,17 +270,17 @@ def test_batch_concatenate(
|
|||||||
- default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
|
- default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
|
||||||
- 2
|
- 2
|
||||||
):
|
):
|
||||||
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
|
generations, next_batch = default_causal_lm.generate_token(next_batch)
|
||||||
assert generated_texts == []
|
assert len(generations) == len(next_batch)
|
||||||
|
|
||||||
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
|
generations, next_batch = default_causal_lm.generate_token(next_batch)
|
||||||
assert next_batch is not None
|
assert next_batch is not None
|
||||||
|
|
||||||
assert len(generated_texts) == 1
|
assert len(generations) == 2
|
||||||
assert generated_texts[0].output_text == "Test.java:784) at net.minecraft."
|
assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
|
||||||
assert generated_texts[0].request == default_causal_lm_batch.requests[0]
|
assert generations[0].request_id == default_causal_lm_batch.requests[0].id
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].generated_tokens
|
generations[0].generated_text.generated_tokens
|
||||||
== default_causal_lm_batch.stopping_criterias[0].max_new_tokens
|
== default_causal_lm_batch.stopping_criterias[0].max_new_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -285,18 +290,19 @@ def test_batch_concatenate(
|
|||||||
- default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
|
- default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
|
||||||
- 4
|
- 4
|
||||||
):
|
):
|
||||||
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
|
generations, next_batch = default_causal_lm.generate_token(next_batch)
|
||||||
assert generated_texts == []
|
assert len(generations) == len(next_batch)
|
||||||
|
|
||||||
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
|
generations, next_batch = default_causal_lm.generate_token(next_batch)
|
||||||
assert next_batch is None
|
assert next_batch is None
|
||||||
|
|
||||||
assert len(generated_texts) == 1
|
assert len(generations) == 1
|
||||||
assert generated_texts[0].output_text == "Test.java:784) at net.minecraft."
|
assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0]
|
generations[0].request_id
|
||||||
|
== default_multi_requests_causal_lm_batch.requests[0].id
|
||||||
)
|
)
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].generated_tokens
|
generations[0].generated_text.generated_tokens
|
||||||
== default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
|
== default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
|
||||||
)
|
)
|
||||||
|
@ -50,18 +50,17 @@ def test_santacoder_generate_token_completion(default_santacoder, default_pb_bat
|
|||||||
next_batch = batch
|
next_batch = batch
|
||||||
|
|
||||||
for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
|
for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
|
||||||
generated_texts, next_batch = default_santacoder.generate_token(next_batch)
|
generations, next_batch = default_santacoder.generate_token(next_batch)
|
||||||
assert generated_texts == []
|
assert len(generations) == len(next_batch)
|
||||||
|
|
||||||
generated_texts, next_batch = default_santacoder.generate_token(next_batch)
|
generations, next_batch = default_santacoder.generate_token(next_batch)
|
||||||
assert next_batch is None
|
assert next_batch is None
|
||||||
|
|
||||||
assert len(generated_texts) == 1
|
assert len(generations) == 1
|
||||||
assert generated_texts[0].output_text == "def test_get_all_users_with_"
|
assert generations[0].generated_text.text == "def test_get_all_users_with_"
|
||||||
assert generated_texts[0].request == batch.requests[0]
|
assert generations[0].request_id == batch.requests[0].id
|
||||||
assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs)
|
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].generated_tokens
|
generations[0].generated_text.generated_tokens
|
||||||
== batch.stopping_criterias[0].max_new_tokens
|
== batch.stopping_criterias[0].max_new_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -76,20 +75,19 @@ def test_fim_santacoder_generate_token_completion(
|
|||||||
next_batch = batch
|
next_batch = batch
|
||||||
|
|
||||||
for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
|
for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
|
||||||
generated_texts, next_batch = default_santacoder.generate_token(next_batch)
|
generations, next_batch = default_santacoder.generate_token(next_batch)
|
||||||
assert generated_texts == []
|
assert len(generations) == len(next_batch)
|
||||||
|
|
||||||
generated_texts, next_batch = default_santacoder.generate_token(next_batch)
|
generations, next_batch = default_santacoder.generate_token(next_batch)
|
||||||
assert next_batch is None
|
assert next_batch is None
|
||||||
|
|
||||||
assert len(generated_texts) == 1
|
assert len(generations) == 1
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].output_text
|
generations[0].generated_text.text
|
||||||
== """<fim-prefix>def<fim-suffix>world<fim-middle>ineProperty(exports, "__esModule", { value"""
|
== """<fim-prefix>def<fim-suffix>world<fim-middle>ineProperty(exports, "__esModule", { value"""
|
||||||
)
|
)
|
||||||
assert generated_texts[0].request == batch.requests[0]
|
assert generations[0].request_id == batch.requests[0].id
|
||||||
assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs)
|
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].generated_tokens
|
generations[0].generated_text.generated_tokens
|
||||||
== batch.stopping_criterias[0].max_new_tokens
|
== batch.stopping_criterias[0].max_new_tokens
|
||||||
)
|
)
|
||||||
|
@ -99,11 +99,11 @@ def test_seq2seq_lm_batch_type(default_seq2seq_lm):
|
|||||||
|
|
||||||
def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch):
|
def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch):
|
||||||
sequence_length = len(default_seq2seq_lm_batch.input_ids[0])
|
sequence_length = len(default_seq2seq_lm_batch.input_ids[0])
|
||||||
generated_texts, next_batch = default_seq2seq_lm.generate_token(
|
generations, next_batch = default_seq2seq_lm.generate_token(
|
||||||
default_seq2seq_lm_batch
|
default_seq2seq_lm_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
assert generated_texts == []
|
assert len(generations) == len(next_batch)
|
||||||
assert isinstance(next_batch, Seq2SeqLMBatch)
|
assert isinstance(next_batch, Seq2SeqLMBatch)
|
||||||
|
|
||||||
assert torch.equal(next_batch.input_ids, default_seq2seq_lm_batch.input_ids)
|
assert torch.equal(next_batch.input_ids, default_seq2seq_lm_batch.input_ids)
|
||||||
@ -145,6 +145,11 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
|
|||||||
for p in next_batch.past_key_values
|
for p in next_batch.past_key_values
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
assert all([generation.generated_text is None for generation in generations])
|
||||||
|
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
|
||||||
|
assert all([generation.token_id.item() == 259 for generation in generations])
|
||||||
|
assert all([generation.token_text == "" for generation in generations])
|
||||||
|
assert generations[0].request_id == 0
|
||||||
|
|
||||||
|
|
||||||
def test_seq2seq_lm_generate_token_completion(
|
def test_seq2seq_lm_generate_token_completion(
|
||||||
@ -152,16 +157,16 @@ def test_seq2seq_lm_generate_token_completion(
|
|||||||
):
|
):
|
||||||
next_batch = default_seq2seq_lm_batch
|
next_batch = default_seq2seq_lm_batch
|
||||||
for _ in range(6):
|
for _ in range(6):
|
||||||
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
||||||
assert generated_texts == []
|
assert len(generations) == len(next_batch)
|
||||||
|
|
||||||
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
||||||
assert next_batch is None
|
assert next_batch is None
|
||||||
|
|
||||||
assert len(generated_texts) == 1
|
assert len(generations) == 1
|
||||||
assert generated_texts[0].output_text == "a few weeks"
|
assert generations[0].generated_text.text == "a few weeks"
|
||||||
assert generated_texts[0].request == default_seq2seq_lm_batch.requests[0]
|
assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
|
||||||
assert generated_texts[0].generated_tokens == 7
|
assert generations[0].generated_text.generated_tokens == 7
|
||||||
|
|
||||||
|
|
||||||
def test_seq2seq_lm_generate_token_completion_multi(
|
def test_seq2seq_lm_generate_token_completion_multi(
|
||||||
@ -170,33 +175,33 @@ def test_seq2seq_lm_generate_token_completion_multi(
|
|||||||
next_batch = default_multi_requests_seq2seq_lm_batch
|
next_batch = default_multi_requests_seq2seq_lm_batch
|
||||||
|
|
||||||
for i in range(4):
|
for i in range(4):
|
||||||
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
||||||
assert generated_texts == []
|
assert len(generations) == len(next_batch)
|
||||||
|
|
||||||
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
||||||
assert next_batch is not None
|
assert next_batch is not None
|
||||||
|
|
||||||
assert len(generated_texts) == 1
|
assert len(generations) == 2
|
||||||
assert generated_texts[0].output_text == "a few "
|
assert generations[1].generated_text.text == "a few "
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].request
|
generations[1].request_id
|
||||||
== default_multi_requests_seq2seq_lm_batch.requests[1]
|
== default_multi_requests_seq2seq_lm_batch.requests[1].id
|
||||||
)
|
)
|
||||||
assert generated_texts[0].generated_tokens == 5
|
assert generations[1].generated_text.generated_tokens == 5
|
||||||
|
|
||||||
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
||||||
assert generated_texts == []
|
assert len(generations) == len(next_batch)
|
||||||
|
|
||||||
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
||||||
assert next_batch is None
|
assert next_batch is None
|
||||||
|
|
||||||
assert len(generated_texts) == 1
|
assert len(generations) == 1
|
||||||
assert generated_texts[0].output_text == "a few weeks"
|
assert generations[0].generated_text.text == "a few weeks"
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].request
|
generations[0].request_id
|
||||||
== default_multi_requests_seq2seq_lm_batch.requests[0]
|
== default_multi_requests_seq2seq_lm_batch.requests[0].id
|
||||||
)
|
)
|
||||||
assert generated_texts[0].generated_tokens == 7
|
assert generations[0].generated_text.generated_tokens == 7
|
||||||
|
|
||||||
|
|
||||||
def test_batch_concatenate(
|
def test_batch_concatenate(
|
||||||
@ -291,35 +296,35 @@ def test_batch_concatenate(
|
|||||||
)
|
)
|
||||||
|
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
||||||
assert generated_texts == []
|
assert len(generations) == len(next_batch)
|
||||||
|
|
||||||
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
||||||
assert next_batch is not None
|
assert next_batch is not None
|
||||||
|
|
||||||
assert len(generated_texts) == 1
|
assert len(generations) == 3
|
||||||
assert generated_texts[0].output_text == "a few "
|
assert generations[2].generated_text.text == "a few "
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].request
|
generations[2].request_id
|
||||||
== default_multi_requests_seq2seq_lm_batch.requests[1]
|
== default_multi_requests_seq2seq_lm_batch.requests[1].id
|
||||||
)
|
)
|
||||||
assert generated_texts[0].generated_tokens == 5
|
assert generations[2].generated_text.generated_tokens == 5
|
||||||
|
|
||||||
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
||||||
assert next_batch is not None
|
assert next_batch is not None
|
||||||
|
|
||||||
assert len(generated_texts) == 1
|
assert len(generations) == 2
|
||||||
assert generated_texts[0].output_text == "a few weeks"
|
assert generations[0].generated_text.text == "a few weeks"
|
||||||
assert generated_texts[0].request == default_seq2seq_lm_batch.requests[0]
|
assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
|
||||||
assert generated_texts[0].generated_tokens == 7
|
assert generations[0].generated_text.generated_tokens == 7
|
||||||
|
|
||||||
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
||||||
assert next_batch is None
|
assert next_batch is None
|
||||||
|
|
||||||
assert len(generated_texts) == 1
|
assert len(generations) == 1
|
||||||
assert generated_texts[0].output_text == "a few weeks"
|
assert generations[0].generated_text.text == "a few weeks"
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].request
|
generations[0].request_id
|
||||||
== default_multi_requests_seq2seq_lm_batch.requests[0]
|
== default_multi_requests_seq2seq_lm_batch.requests[0].id
|
||||||
)
|
)
|
||||||
assert generated_texts[0].generated_tokens == 7
|
assert generations[0].generated_text.generated_tokens == 7
|
||||||
|
@ -343,9 +343,11 @@ class CausalLM(Model):
|
|||||||
# Generated token
|
# Generated token
|
||||||
next_token_logprob = logprobs[-1, next_token_id]
|
next_token_logprob = logprobs[-1, next_token_id]
|
||||||
next_token_id_squeezed = next_token_id.squeeze()
|
next_token_id_squeezed = next_token_id.squeeze()
|
||||||
next_token_text = self.tokenizer.decode(next_token_id_squeezed,
|
next_token_text = self.tokenizer.decode(
|
||||||
clean_up_tokenization_spaces=False,
|
next_token_id_squeezed,
|
||||||
skip_special_tokens=False)
|
clean_up_tokenization_spaces=False,
|
||||||
|
skip_special_tokens=False,
|
||||||
|
)
|
||||||
|
|
||||||
# Evaluate stopping criteria
|
# Evaluate stopping criteria
|
||||||
stop, reason = stopping_criteria(
|
stop, reason = stopping_criteria(
|
||||||
@ -385,7 +387,9 @@ class CausalLM(Model):
|
|||||||
# Prefill
|
# Prefill
|
||||||
if stopping_criteria.current_tokens == 1:
|
if stopping_criteria.current_tokens == 1:
|
||||||
# Remove generated token to only have prefill and add nan for first prompt token
|
# Remove generated token to only have prefill and add nan for first prompt token
|
||||||
prefill_logprobs = [float("nan")] + logprobs.gather(1, all_input_ids[1:]).squeeze(1)[-new_input_length:-1].tolist()
|
prefill_logprobs = [float("nan")] + logprobs.gather(
|
||||||
|
1, all_input_ids[1:]
|
||||||
|
).squeeze(1)[-new_input_length:-1].tolist()
|
||||||
prefill_token_ids = all_input_ids[-new_input_length:-1]
|
prefill_token_ids = all_input_ids[-new_input_length:-1]
|
||||||
prefill_texts = self.tokenizer.batch_decode(
|
prefill_texts = self.tokenizer.batch_decode(
|
||||||
prefill_token_ids,
|
prefill_token_ids,
|
||||||
|
@ -50,10 +50,10 @@ class Seq2SeqLMBatch(Batch):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pb(
|
def from_pb(
|
||||||
cls,
|
cls,
|
||||||
pb: generate_pb2.Batch,
|
pb: generate_pb2.Batch,
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
) -> "Seq2SeqLMBatch":
|
) -> "Seq2SeqLMBatch":
|
||||||
"""Convert a text_generation.v1.Batch protobuf to a Seq2SeqLMBatch"""
|
"""Convert a text_generation.v1.Batch protobuf to a Seq2SeqLMBatch"""
|
||||||
inputs = []
|
inputs = []
|
||||||
@ -158,8 +158,8 @@ class Seq2SeqLMBatch(Batch):
|
|||||||
)
|
)
|
||||||
# Copy to correct indices
|
# Copy to correct indices
|
||||||
input_ids[
|
input_ids[
|
||||||
start_index:end_index, -batch.max_input_length:
|
start_index:end_index, -batch.max_input_length :
|
||||||
] = batch.input_ids[:, -batch.max_input_length:]
|
] = batch.input_ids[:, -batch.max_input_length :]
|
||||||
|
|
||||||
# Create padded tensor
|
# Create padded tensor
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
@ -168,8 +168,8 @@ class Seq2SeqLMBatch(Batch):
|
|||||||
)
|
)
|
||||||
# Copy to correct indices
|
# Copy to correct indices
|
||||||
attention_mask[
|
attention_mask[
|
||||||
start_index:end_index, -batch.max_input_length:
|
start_index:end_index, -batch.max_input_length :
|
||||||
] = batch.attention_mask[:, -batch.max_input_length:]
|
] = batch.attention_mask[:, -batch.max_input_length :]
|
||||||
|
|
||||||
# Create padded tensor
|
# Create padded tensor
|
||||||
if decoder_input_ids is None:
|
if decoder_input_ids is None:
|
||||||
@ -178,8 +178,8 @@ class Seq2SeqLMBatch(Batch):
|
|||||||
)
|
)
|
||||||
# Copy to correct indices
|
# Copy to correct indices
|
||||||
decoder_input_ids[
|
decoder_input_ids[
|
||||||
start_index:end_index, -batch.max_decoder_input_length:
|
start_index:end_index, -batch.max_decoder_input_length :
|
||||||
] = batch.decoder_input_ids[:, -batch.max_decoder_input_length:]
|
] = batch.decoder_input_ids[:, -batch.max_decoder_input_length :]
|
||||||
|
|
||||||
# Create padded tensor
|
# Create padded tensor
|
||||||
if decoder_attention_mask is None:
|
if decoder_attention_mask is None:
|
||||||
@ -191,13 +191,13 @@ class Seq2SeqLMBatch(Batch):
|
|||||||
# this batch. All generations are of length `batch.max_decoder_input_length`.
|
# this batch. All generations are of length `batch.max_decoder_input_length`.
|
||||||
if batch.decoder_attention_mask is None:
|
if batch.decoder_attention_mask is None:
|
||||||
decoder_attention_mask[
|
decoder_attention_mask[
|
||||||
start_index:end_index, -batch.max_decoder_input_length:
|
start_index:end_index, -batch.max_decoder_input_length :
|
||||||
] = 1
|
] = 1
|
||||||
# If it exists, we need to index
|
# If it exists, we need to index
|
||||||
else:
|
else:
|
||||||
decoder_attention_mask[
|
decoder_attention_mask[
|
||||||
start_index:end_index, -batch.max_decoder_input_length:
|
start_index:end_index, -batch.max_decoder_input_length :
|
||||||
] = batch.decoder_attention_mask[:, -batch.max_decoder_input_length:]
|
] = batch.decoder_attention_mask[:, -batch.max_decoder_input_length :]
|
||||||
|
|
||||||
# Create padded tensor
|
# Create padded tensor
|
||||||
if encoder_last_hidden_state is None:
|
if encoder_last_hidden_state is None:
|
||||||
@ -211,8 +211,8 @@ class Seq2SeqLMBatch(Batch):
|
|||||||
|
|
||||||
# Copy to correct indices
|
# Copy to correct indices
|
||||||
encoder_last_hidden_state[
|
encoder_last_hidden_state[
|
||||||
start_index:end_index, -batch.max_input_length:, :
|
start_index:end_index, -batch.max_input_length :, :
|
||||||
] = batch.encoder_last_hidden_state[:, -batch.max_input_length:, :]
|
] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :]
|
||||||
|
|
||||||
# Iterate over attention layers
|
# Iterate over attention layers
|
||||||
for j, past in enumerate(batch.past_key_values):
|
for j, past in enumerate(batch.past_key_values):
|
||||||
@ -238,11 +238,11 @@ class Seq2SeqLMBatch(Batch):
|
|||||||
|
|
||||||
# We slice the past keys and values to remove the padding from previous batches
|
# We slice the past keys and values to remove the padding from previous batches
|
||||||
past_key_values[j][k][
|
past_key_values[j][k][
|
||||||
start_index:end_index,
|
start_index:end_index,
|
||||||
:,
|
:,
|
||||||
-(batch.max_decoder_input_length - 1):,
|
-(batch.max_decoder_input_length - 1) :,
|
||||||
:,
|
:,
|
||||||
] = t[:, :, -(batch.max_decoder_input_length - 1):, :]
|
] = t[:, :, -(batch.max_decoder_input_length - 1) :, :]
|
||||||
|
|
||||||
# encoder past
|
# encoder past
|
||||||
for k, t in enumerate(past[2:]):
|
for k, t in enumerate(past[2:]):
|
||||||
@ -261,8 +261,8 @@ class Seq2SeqLMBatch(Batch):
|
|||||||
past_key_values[j].append(t.new_zeros(padded_t_shape))
|
past_key_values[j].append(t.new_zeros(padded_t_shape))
|
||||||
|
|
||||||
past_key_values[j][idx][
|
past_key_values[j][idx][
|
||||||
start_index:end_index, :, -batch.max_input_length:, :
|
start_index:end_index, :, -batch.max_input_length :, :
|
||||||
] = t[:, :, -batch.max_input_length:, :]
|
] = t[:, :, -batch.max_input_length :, :]
|
||||||
|
|
||||||
start_index += batch.size
|
start_index += batch.size
|
||||||
|
|
||||||
@ -322,13 +322,13 @@ class Seq2SeqLM(Model):
|
|||||||
return self.tokenizer.decode(decoder_ids, skip_special_tokens=True)
|
return self.tokenizer.decode(decoder_ids, skip_special_tokens=True)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
decoder_input_ids,
|
decoder_input_ids,
|
||||||
decoder_attention_mask: Optional,
|
decoder_attention_mask: Optional,
|
||||||
encoder_last_hidden_state: Optional,
|
encoder_last_hidden_state: Optional,
|
||||||
past_key_values: Optional = None,
|
past_key_values: Optional = None,
|
||||||
) -> Tuple[
|
) -> Tuple[
|
||||||
torch.Tensor,
|
torch.Tensor,
|
||||||
torch.Tensor,
|
torch.Tensor,
|
||||||
@ -359,7 +359,7 @@ class Seq2SeqLM(Model):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def generate_token(
|
def generate_token(
|
||||||
self, batch: Seq2SeqLMBatch
|
self, batch: Seq2SeqLMBatch
|
||||||
) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
|
) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
|
||||||
# For some reason, inference_mode does not work well with GLOO which we use on CPU
|
# For some reason, inference_mode does not work well with GLOO which we use on CPU
|
||||||
context_manager = (
|
context_manager = (
|
||||||
@ -405,14 +405,14 @@ class Seq2SeqLM(Model):
|
|||||||
|
|
||||||
# For each member of the batch
|
# For each member of the batch
|
||||||
for i, (
|
for i, (
|
||||||
request,
|
request,
|
||||||
input_length,
|
input_length,
|
||||||
decoder_input_length,
|
decoder_input_length,
|
||||||
logits,
|
logits,
|
||||||
next_token_chooser,
|
next_token_chooser,
|
||||||
stopping_criteria,
|
stopping_criteria,
|
||||||
input_tokens,
|
input_tokens,
|
||||||
decoder_input_ids,
|
decoder_input_ids,
|
||||||
) in enumerate(iterator):
|
) in enumerate(iterator):
|
||||||
# Select next token
|
# Select next token
|
||||||
next_token_id, logprobs = next_token_chooser(decoder_input_ids, logits)
|
next_token_id, logprobs = next_token_chooser(decoder_input_ids, logits)
|
||||||
@ -424,15 +424,14 @@ class Seq2SeqLM(Model):
|
|||||||
# Generated token
|
# Generated token
|
||||||
next_token_logprob = logprobs[-1, next_token_id]
|
next_token_logprob = logprobs[-1, next_token_id]
|
||||||
next_token_id_squeezed = next_token_id.squeeze()
|
next_token_id_squeezed = next_token_id.squeeze()
|
||||||
next_token_text = self.tokenizer.decode(next_token_id_squeezed,
|
next_token_text = self.tokenizer.decode(
|
||||||
clean_up_tokenization_spaces=False,
|
next_token_id_squeezed,
|
||||||
skip_special_tokens=False)
|
clean_up_tokenization_spaces=False,
|
||||||
|
skip_special_tokens=False,
|
||||||
|
)
|
||||||
|
|
||||||
# Evaluate stopping criteria
|
# Evaluate stopping criteria
|
||||||
stop, reason = stopping_criteria(
|
stop, reason = stopping_criteria(next_token_id, next_token_text)
|
||||||
next_token_id,
|
|
||||||
next_token_text
|
|
||||||
)
|
|
||||||
|
|
||||||
if stop:
|
if stop:
|
||||||
# Slice with decoder_input_length to remove padding
|
# Slice with decoder_input_length to remove padding
|
||||||
|
@ -61,6 +61,9 @@ class PrefillTokens:
|
|||||||
ids=self.token_ids, logprobs=self.logprobs, texts=self.texts
|
ids=self.token_ids, logprobs=self.logprobs, texts=self.texts
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.token_ids)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Generation:
|
class Generation:
|
||||||
|
Loading…
Reference in New Issue
Block a user