From 42cdb734a53fdca5ae4785932f4e8485d24c0ddc Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Mon, 30 Jan 2023 12:18:53 +0100
Subject: [PATCH] working python tests

---
 router/client/src/lib.rs                    |   4 +-
 router/src/infer.rs                         |   8 +-
 server/tests/models/test_bloom.py           | 105 ++++++++++++--------
 server/tests/models/test_causal_lm.py       | 100 ++++++++++---------
 server/tests/models/test_santacoder.py      |  30 +++---
 server/tests/models/test_seq2seq_lm.py      |  93 +++++++++--------
 server/text_generation/models/causal_lm.py  |  12 ++-
 server/text_generation/models/seq2seq_lm.py |  89 ++++++++---------
 server/text_generation/models/types.py      |   3 +
 9 files changed, 239 insertions(+), 205 deletions(-)

diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs
index d0724625..e0546b16 100644
--- a/router/client/src/lib.rs
+++ b/router/client/src/lib.rs
@@ -7,8 +7,8 @@ mod sharded_client;
 
 pub use client::Client;
 pub use pb::generate::v1::{
-    Batch, GeneratedText, Generation, NextTokenChooserParameters, Request,
-    StoppingCriteriaParameters, PrefillTokens
+    Batch, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens, Request,
+    StoppingCriteriaParameters,
 };
 pub use sharded_client::ShardedClient;
 use thiserror::Error;
diff --git a/router/src/infer.rs b/router/src/infer.rs
index 065313d6..76bcc0ef 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -5,7 +5,9 @@ use crate::{Db, Entry, Token};
 use nohash_hasher::IntMap;
 use std::future::Future;
 use std::sync::Arc;
-use text_generation_client::{Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient};
+use text_generation_client::{
+    Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
+};
 use thiserror::Error;
 use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
 use tokio::time::Instant;
@@ -88,7 +90,7 @@ impl Infer {
             input_length,
             time: Instant::now(),
             batch_time: None,
-            _permit: permit
+            _permit: permit,
         });
 
         // Notify the background task that we have a new entry in the database that needs
@@ -233,7 +235,7 @@ async fn batching_task(
 
 /// Wrap a future inside a match statement to handle errors and send the responses to Infer
 async fn wrap_future(
-    future: impl Future<Output = Result<(Vec<Generation>, Option<Batch>), ClientError>>,
+    future: impl Future<Output = Result<(Vec<Generation>, Option<Batch>), ClientError>>,
     entries: &mut IntMap<u64, Entry>,
 ) -> Option<Batch> {
     match future.await {
diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py
index 1a788ce5..9f96efc3 100644
--- a/server/tests/models/test_bloom.py
+++ b/server/tests/models/test_bloom.py
@@ -91,9 +91,9 @@ def test_causal_lm_batch_type(default_bloom):
 
 def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
     sequence_length = len(default_bloom_batch.all_input_ids[0])
-    generated_texts, next_batch = default_bloom.generate_token(default_bloom_batch)
+    generations, next_batch = default_bloom.generate_token(default_bloom_batch)
 
-    assert generated_texts == []
+    assert len(generations) == len(default_bloom_batch)
     assert isinstance(next_batch, CausalLMBatch)
     assert not next_batch.keys_head_dim_last
 
@@ -122,24 +122,30 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
     assert all(
         [p[1].shape == (16, sequence_length, 64) for p in next_batch.past_key_values]
     )
+    assert all([generation.generated_text is None for generation in generations])
+    assert all([len(generation.prefill_tokens) == 1 for generation in generations])
+    assert all([generation.token_id.item() == 10264 for generation in generations])
+    assert all([generation.token_text == "Test" for generation in generations])
+    assert generations[0].request_id == 0
 
 
 def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch):
     next_batch = default_bloom_batch
     for _ in range(default_bloom_batch.stopping_criterias[0].max_new_tokens - 1):
-        generated_texts, next_batch = default_bloom.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_bloom.generate_token(next_batch)
+        assert len(generations) == len(default_bloom_batch)
 
-    generated_texts, next_batch = default_bloom.generate_token(next_batch)
+    generations, next_batch = default_bloom.generate_token(next_batch)
     assert next_batch is None
 
-    assert len(generated_texts) == 1
+    assert len(generations) == 1
     assert (
-        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+        generations[0].generated_text.text
+        == "TestTestTestTestTestTestTestTestTestTestTest"
     )
-    assert generated_texts[0].request == default_bloom_batch.requests[0]
+    assert generations[0].request_id == default_bloom_batch.requests[0].id
     assert (
-        generated_texts[0].generated_tokens
+        generations[0].generated_text.generated_tokens
         == default_bloom_batch.stopping_criterias[0].max_new_tokens
     )
 
@@ -152,17 +158,19 @@ def test_causal_lm_generate_token_completion_multi(
     for i in range(
         default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 1
     ):
-        generated_texts, next_batch = default_bloom.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_bloom.generate_token(next_batch)
+        assert len(generations) == len(default_multi_requests_bloom_batch)
 
-    generated_texts, next_batch = default_bloom.generate_token(next_batch)
+    generations, next_batch = default_bloom.generate_token(next_batch)
     assert next_batch is not None
 
-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "TestTestTestTestTestTest"
-    assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1]
+    assert len(generations) == 2
+    assert generations[1].generated_text.text == "TestTestTestTestTestTest"
     assert (
-        generated_texts[0].generated_tokens
+        generations[1].request_id == default_multi_requests_bloom_batch.requests[1].id
+    )
+    assert (
+        generations[1].generated_text.generated_tokens
         == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
     )
 
@@ -171,19 +179,22 @@ def test_causal_lm_generate_token_completion_multi(
         - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
         - 1
     ):
-        generated_texts, next_batch = default_bloom.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_bloom.generate_token(next_batch)
+        assert len(generations) == len(next_batch)
 
-    generated_texts, next_batch = default_bloom.generate_token(next_batch)
+    generations, next_batch = default_bloom.generate_token(next_batch)
     assert next_batch is None
 
-    assert len(generated_texts) == 1
+    assert len(generations) == 1
     assert (
-        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+        generations[0].generated_text.text
+        == "TestTestTestTestTestTestTestTestTestTestTest"
     )
-    assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
     assert (
-        generated_texts[0].generated_tokens
+        generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
+    )
+    assert (
+        generations[0].generated_text.generated_tokens
         == default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
     )
@@ -243,17 +254,19 @@ def test_batch_concatenate(
     for _ in range(
         default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 2
    ):
-        generated_texts, next_batch = default_bloom.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_bloom.generate_token(next_batch)
+        assert len(generations) == len(next_batch)
 
-    generated_texts, next_batch = default_bloom.generate_token(next_batch)
+    generations, next_batch = default_bloom.generate_token(next_batch)
     assert next_batch is not None
 
-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "TestTestTestTestTestTest"
-    assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1]
+    assert len(generations) == 3
+    assert generations[2].generated_text.text == "TestTestTestTestTestTest"
     assert (
-        generated_texts[0].generated_tokens
+        generations[2].request_id == default_multi_requests_bloom_batch.requests[1].id
+    )
+    assert (
+        generations[2].generated_text.generated_tokens
         == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
     )
 
@@ -262,19 +275,20 @@ def test_batch_concatenate(
         - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
         - 2
     ):
-        generated_texts, next_batch = default_bloom.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_bloom.generate_token(next_batch)
+        assert len(generations) == len(next_batch)
 
-    generated_texts, next_batch = default_bloom.generate_token(next_batch)
+    generations, next_batch = default_bloom.generate_token(next_batch)
     assert next_batch is not None
 
-    assert len(generated_texts) == 1
+    assert len(generations) == 2
     assert (
-        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+        generations[0].generated_text.text
+        == "TestTestTestTestTestTestTestTestTestTestTest"
     )
-    assert generated_texts[0].request == default_bloom_batch.requests[0]
+    assert generations[0].request_id == default_bloom_batch.requests[0].id
     assert (
-        generated_texts[0].generated_tokens
+        generations[0].generated_text.generated_tokens
         == default_bloom_batch.stopping_criterias[0].max_new_tokens
     )
 
@@ -284,18 +298,21 @@ def test_batch_concatenate(
         - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
         - 4
     ):
-        generated_texts, next_batch = default_bloom.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_bloom.generate_token(next_batch)
+        assert len(generations) == len(next_batch)
 
-    generated_texts, next_batch = default_bloom.generate_token(next_batch)
+    generations, next_batch = default_bloom.generate_token(next_batch)
     assert next_batch is None
 
-    assert len(generated_texts) == 1
+    assert len(generations) == 1
     assert (
-        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+        generations[0].generated_text.text
+        == "TestTestTestTestTestTestTestTestTestTestTest"
     )
-    assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
     assert (
-        generated_texts[0].generated_tokens
+        generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
+    )
+    assert (
+        generations[0].generated_text.generated_tokens
         == default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
     )
diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py
index bedb65ba..f9762b30 100644
--- a/server/tests/models/test_causal_lm.py
+++ b/server/tests/models/test_causal_lm.py
@@ -88,11 +88,9 @@ def test_causal_lm_batch_type(default_causal_lm):
 
 def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
     sequence_length = len(default_causal_lm_batch.all_input_ids[0])
-    generated_texts, next_batch = default_causal_lm.generate_token(
-        default_causal_lm_batch
-    )
+    generations, next_batch = default_causal_lm.generate_token(default_causal_lm_batch)
 
-    assert generated_texts == []
+    assert len(generations) == len(next_batch)
     assert isinstance(next_batch, CausalLMBatch)
 
     assert len(next_batch.all_input_ids) == next_batch.size
@@ -121,6 +119,11 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
     assert all(
         [p[1].shape == (1, 12, sequence_length, 64) for p in next_batch.past_key_values]
     )
+    assert all([generation.generated_text is None for generation in generations])
+    assert all([len(generation.prefill_tokens) == 1 for generation in generations])
+    assert all([generation.token_id.item() == 13 for generation in generations])
+    assert all([generation.token_text == "." for generation in generations])
+    assert generations[0].request_id == 0
 
 
 def test_causal_lm_generate_token_completion(
@@ -128,18 +131,17 @@
 ):
     next_batch = default_causal_lm_batch
     for _ in range(default_causal_lm_batch.stopping_criterias[0].max_new_tokens - 1):
-        generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_causal_lm.generate_token(next_batch)
+        assert len(generations) == len(next_batch)
 
-    generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
+    generations, next_batch = default_causal_lm.generate_token(next_batch)
     assert next_batch is None
 
-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "Test.java:784) at net.minecraft."
-    assert generated_texts[0].request == default_causal_lm_batch.requests[0]
-    assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs)
+    assert len(generations) == 1
+    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].request_id == default_causal_lm_batch.requests[0].id
     assert (
-        generated_texts[0].generated_tokens
+        generations[0].generated_text.generated_tokens
         == default_causal_lm_batch.stopping_criterias[0].max_new_tokens
     )
 
@@ -152,19 +154,20 @@ def test_causal_lm_generate_token_completion_multi(
     for i in range(
         default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 1
     ):
-        generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_causal_lm.generate_token(next_batch)
+        assert len(generations) == len(next_batch)
 
-    generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
+    generations, next_batch = default_causal_lm.generate_token(next_batch)
     assert next_batch is not None
 
-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "Test.java:784)"
+    assert len(generations) == 2
+    assert generations[1].generated_text.text == "Test.java:784)"
     assert (
-        generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1]
+        generations[1].request_id
+        == default_multi_requests_causal_lm_batch.requests[1].id
     )
     assert (
-        generated_texts[0].generated_tokens
+        generations[1].generated_text.generated_tokens
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )
 
@@ -173,19 +176,20 @@ def test_causal_lm_generate_token_completion_multi(
         - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
        - 1
     ):
-        generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_causal_lm.generate_token(next_batch)
+        assert len(generations) == len(next_batch)
 
-    generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
+    generations, next_batch = default_causal_lm.generate_token(next_batch)
     assert next_batch is None
 
-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "Test.java:784) at net.minecraft."
+    assert len(generations) == 1
+    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
     assert (
-        generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0]
+        generations[0].request_id
+        == default_multi_requests_causal_lm_batch.requests[0].id
     )
     assert (
-        generated_texts[0].generated_tokens
+        generations[0].generated_text.generated_tokens
         == default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
     )
 
@@ -244,19 +248,20 @@ def test_batch_concatenate(
     for _ in range(
         default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 2
     ):
-        generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_causal_lm.generate_token(next_batch)
+        assert len(generations) == len(next_batch)
 
-    generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
+    generations, next_batch = default_causal_lm.generate_token(next_batch)
     assert next_batch is not None
 
-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "Test.java:784)"
+    assert len(generations) == 3
+    assert generations[2].generated_text.text == "Test.java:784)"
     assert (
-        generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1]
+        generations[2].request_id
+        == default_multi_requests_causal_lm_batch.requests[1].id
     )
     assert (
-        generated_texts[0].generated_tokens
+        generations[2].generated_text.generated_tokens
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )
 
@@ -265,17 +270,17 @@ def test_batch_concatenate(
         - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
         - 2
     ):
-        generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_causal_lm.generate_token(next_batch)
+        assert len(generations) == len(next_batch)
 
-    generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
+    generations, next_batch = default_causal_lm.generate_token(next_batch)
     assert next_batch is not None
 
-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "Test.java:784) at net.minecraft."
-    assert generated_texts[0].request == default_causal_lm_batch.requests[0]
+    assert len(generations) == 2
+    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].request_id == default_causal_lm_batch.requests[0].id
     assert (
-        generated_texts[0].generated_tokens
+        generations[0].generated_text.generated_tokens
         == default_causal_lm_batch.stopping_criterias[0].max_new_tokens
     )
 
@@ -285,18 +290,19 @@ def test_batch_concatenate(
         - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
         - 4
     ):
-        generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_causal_lm.generate_token(next_batch)
+        assert len(generations) == len(next_batch)
 
-    generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
+    generations, next_batch = default_causal_lm.generate_token(next_batch)
     assert next_batch is None
 
-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "Test.java:784) at net.minecraft."
+    assert len(generations) == 1
+    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
     assert (
-        generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0]
+        generations[0].request_id
+        == default_multi_requests_causal_lm_batch.requests[0].id
     )
     assert (
-        generated_texts[0].generated_tokens
+        generations[0].generated_text.generated_tokens
         == default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
     )
diff --git a/server/tests/models/test_santacoder.py b/server/tests/models/test_santacoder.py
index acebec04..1b69477d 100644
--- a/server/tests/models/test_santacoder.py
+++ b/server/tests/models/test_santacoder.py
@@ -50,18 +50,17 @@ def test_santacoder_generate_token_completion(default_santacoder, default_pb_bat
     next_batch = batch
 
     for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
-        generated_texts, next_batch = default_santacoder.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_santacoder.generate_token(next_batch)
+        assert len(generations) == len(next_batch)
 
-    generated_texts, next_batch = default_santacoder.generate_token(next_batch)
+    generations, next_batch = default_santacoder.generate_token(next_batch)
     assert next_batch is None
 
-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "def test_get_all_users_with_"
-    assert generated_texts[0].request == batch.requests[0]
-    assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs)
+    assert len(generations) == 1
+    assert generations[0].generated_text.text == "def test_get_all_users_with_"
+    assert generations[0].request_id == batch.requests[0].id
     assert (
-        generated_texts[0].generated_tokens
+        generations[0].generated_text.generated_tokens
         == batch.stopping_criterias[0].max_new_tokens
     )
@@ -76,20 +75,19 @@ def test_fim_santacoder_generate_token_completion(
     next_batch = batch
 
     for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
-        generated_texts, next_batch = default_santacoder.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_santacoder.generate_token(next_batch)
+        assert len(generations) == len(next_batch)
 
-    generated_texts, next_batch = default_santacoder.generate_token(next_batch)
+    generations, next_batch = default_santacoder.generate_token(next_batch)
     assert next_batch is None
 
-    assert len(generated_texts) == 1
+    assert len(generations) == 1
     assert (
-        generated_texts[0].output_text
+        generations[0].generated_text.text
         == """defworldineProperty(exports, "__esModule", { value"""
     )
-    assert generated_texts[0].request == batch.requests[0]
-    assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs)
+    assert generations[0].request_id == batch.requests[0].id
     assert (
-        generated_texts[0].generated_tokens
+        generations[0].generated_text.generated_tokens
         == batch.stopping_criterias[0].max_new_tokens
     )
diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py
index de1a4829..22c6ac9c 100644
--- a/server/tests/models/test_seq2seq_lm.py
+++ b/server/tests/models/test_seq2seq_lm.py
@@ -99,11 +99,11 @@ def test_seq2seq_lm_batch_type(default_seq2seq_lm):
 
 def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch):
     sequence_length = len(default_seq2seq_lm_batch.input_ids[0])
-    generated_texts, next_batch = default_seq2seq_lm.generate_token(
+    generations, next_batch = default_seq2seq_lm.generate_token(
         default_seq2seq_lm_batch
     )
 
-    assert generated_texts == []
+    assert len(generations) == len(next_batch)
     assert isinstance(next_batch, Seq2SeqLMBatch)
 
     assert torch.equal(next_batch.input_ids, default_seq2seq_lm_batch.input_ids)
@@ -145,6 +145,11 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
             for p in next_batch.past_key_values
         ]
     )
+    assert all([generation.generated_text is None for generation in generations])
+    assert all([len(generation.prefill_tokens) == 1 for generation in generations])
+    assert all([generation.token_id.item() == 259 for generation in generations])
+    assert all([generation.token_text == "" for generation in generations])
+    assert generations[0].request_id == 0
 
 
 def test_seq2seq_lm_generate_token_completion(
@@ -152,16 +157,16 @@
 ):
     next_batch = default_seq2seq_lm_batch
     for _ in range(6):
-        generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+        assert len(generations) == len(next_batch)
 
-    generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is None
 
-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "a few weeks"
-    assert generated_texts[0].request == default_seq2seq_lm_batch.requests[0]
-    assert generated_texts[0].generated_tokens == 7
+    assert len(generations) == 1
+    assert generations[0].generated_text.text == "a few weeks"
+    assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
+    assert generations[0].generated_text.generated_tokens == 7
 
 
 def test_seq2seq_lm_generate_token_completion_multi(
@@ -170,33 +175,33 @@
     next_batch = default_multi_requests_seq2seq_lm_batch
 
     for i in range(4):
-        generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+        assert len(generations) == len(next_batch)
 
-    generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None
 
-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "a few "
+    assert len(generations) == 2
+    assert generations[1].generated_text.text == "a few "
     assert (
-        generated_texts[0].request
-        == default_multi_requests_seq2seq_lm_batch.requests[1]
+        generations[1].request_id
+        == default_multi_requests_seq2seq_lm_batch.requests[1].id
     )
-    assert generated_texts[0].generated_tokens == 5
+    assert generations[1].generated_text.generated_tokens == 5
 
-    generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
-    assert generated_texts == []
+    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    assert len(generations) == len(next_batch)
 
-    generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is None
 
-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "a few weeks"
+    assert len(generations) == 1
+    assert generations[0].generated_text.text == "a few weeks"
     assert (
-        generated_texts[0].request
-        == default_multi_requests_seq2seq_lm_batch.requests[0]
+        generations[0].request_id
+        == default_multi_requests_seq2seq_lm_batch.requests[0].id
     )
-    assert generated_texts[0].generated_tokens == 7
+    assert generations[0].generated_text.generated_tokens == 7
 
 
 def test_batch_concatenate(
@@ -291,35 +296,35 @@
     )
 
     for _ in range(3):
-        generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+        assert len(generations) == len(next_batch)
 
-    generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None
 
-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "a few "
+    assert len(generations) == 3
+    assert generations[2].generated_text.text == "a few "
     assert (
-        generated_texts[0].request
-        == default_multi_requests_seq2seq_lm_batch.requests[1]
+        generations[2].request_id
+        == default_multi_requests_seq2seq_lm_batch.requests[1].id
     )
-    assert generated_texts[0].generated_tokens == 5
+    assert generations[2].generated_text.generated_tokens == 5
 
-    generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None
 
-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "a few weeks"
-    assert generated_texts[0].request == default_seq2seq_lm_batch.requests[0]
-    assert generated_texts[0].generated_tokens == 7
+    assert len(generations) == 2
+    assert generations[0].generated_text.text == "a few weeks"
+    assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
+    assert generations[0].generated_text.generated_tokens == 7
 
-    generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is None
 
-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "a few weeks"
+    assert len(generations) == 1
+    assert generations[0].generated_text.text == "a few weeks"
     assert (
-        generated_texts[0].request
-        == default_multi_requests_seq2seq_lm_batch.requests[0]
+        generations[0].request_id
+        == default_multi_requests_seq2seq_lm_batch.requests[0].id
     )
-    assert generated_texts[0].generated_tokens == 7
+    assert generations[0].generated_text.generated_tokens == 7
diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py
index 1642602c..e34e5a38 100644
--- a/server/text_generation/models/causal_lm.py
+++ b/server/text_generation/models/causal_lm.py
@@ -343,9 +343,11 @@ class CausalLM(Model):
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id]
             next_token_id_squeezed = next_token_id.squeeze()
-            next_token_text = self.tokenizer.decode(next_token_id_squeezed,
-                                                    clean_up_tokenization_spaces=False,
-                                                    skip_special_tokens=False)
+            next_token_text = self.tokenizer.decode(
+                next_token_id_squeezed,
+                clean_up_tokenization_spaces=False,
+                skip_special_tokens=False,
+            )
 
             # Evaluate stopping criteria
             stop, reason = stopping_criteria(
@@ -385,7 +387,9 @@ class CausalLM(Model):
             # Prefill
             if stopping_criteria.current_tokens == 1:
                 # Remove generated token to only have prefill and add nan for first prompt token
-                prefill_logprobs = [float("nan")] + logprobs.gather(1, all_input_ids[1:]).squeeze(1)[-new_input_length:-1].tolist()
+                prefill_logprobs = [float("nan")] + logprobs.gather(
+                    1, all_input_ids[1:]
+                ).squeeze(1)[-new_input_length:-1].tolist()
                 prefill_token_ids = all_input_ids[-new_input_length:-1]
                 prefill_texts = self.tokenizer.batch_decode(
                     prefill_token_ids,
diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py
index 716c944f..6d5dc22e 100644
--- a/server/text_generation/models/seq2seq_lm.py
+++ b/server/text_generation/models/seq2seq_lm.py
@@ -50,10 +50,10 @@ class Seq2SeqLMBatch(Batch):
 
     @classmethod
     def from_pb(
-            cls,
-            pb: generate_pb2.Batch,
-            tokenizer: PreTrainedTokenizerBase,
-            device: torch.device,
+        cls,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        device: torch.device,
     ) -> "Seq2SeqLMBatch":
         """Convert a text_generation.v1.Batch protobuf to a Seq2SeqLMBatch"""
         inputs = []
@@ -158,8 +158,8 @@ class Seq2SeqLMBatch(Batch):
             )
             # Copy to correct indices
             input_ids[
-                start_index:end_index, -batch.max_input_length:
-            ] = batch.input_ids[:, -batch.max_input_length:]
+                start_index:end_index, -batch.max_input_length :
+            ] = batch.input_ids[:, -batch.max_input_length :]
 
             # Create padded tensor
             if attention_mask is None:
@@ -168,8 +168,8 @@ class Seq2SeqLMBatch(Batch):
             )
             # Copy to correct indices
             attention_mask[
-                start_index:end_index, -batch.max_input_length:
-            ] = batch.attention_mask[:, -batch.max_input_length:]
+                start_index:end_index, -batch.max_input_length :
+            ] = batch.attention_mask[:, -batch.max_input_length :]
 
             # Create padded tensor
             if decoder_input_ids is None:
@@ -178,8 +178,8 @@ class Seq2SeqLMBatch(Batch):
             )
             # Copy to correct indices
             decoder_input_ids[
-                start_index:end_index, -batch.max_decoder_input_length:
-            ] = batch.decoder_input_ids[:, -batch.max_decoder_input_length:]
+                start_index:end_index, -batch.max_decoder_input_length :
+            ] = batch.decoder_input_ids[:, -batch.max_decoder_input_length :]
 
             # Create padded tensor
             if decoder_attention_mask is None:
@@ -191,13 +191,13 @@ class Seq2SeqLMBatch(Batch):
             # this batch. All generations are of length `batch.max_decoder_input_length`.
             if batch.decoder_attention_mask is None:
                 decoder_attention_mask[
-                    start_index:end_index, -batch.max_decoder_input_length:
+                    start_index:end_index, -batch.max_decoder_input_length :
                 ] = 1
             # If it exists, we need to index
             else:
                 decoder_attention_mask[
-                    start_index:end_index, -batch.max_decoder_input_length:
-                ] = batch.decoder_attention_mask[:, -batch.max_decoder_input_length:]
+                    start_index:end_index, -batch.max_decoder_input_length :
+                ] = batch.decoder_attention_mask[:, -batch.max_decoder_input_length :]
 
             # Create padded tensor
             if encoder_last_hidden_state is None:
@@ -211,8 +211,8 @@ class Seq2SeqLMBatch(Batch):
 
             # Copy to correct indices
             encoder_last_hidden_state[
-                start_index:end_index, -batch.max_input_length:, :
-            ] = batch.encoder_last_hidden_state[:, -batch.max_input_length:, :]
+                start_index:end_index, -batch.max_input_length :, :
+            ] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :]
 
             # Iterate over attention layers
             for j, past in enumerate(batch.past_key_values):
@@ -238,11 +238,11 @@ class Seq2SeqLMBatch(Batch):
 
                     # We slice the past keys and values to remove the padding from previous batches
                     past_key_values[j][k][
-                            start_index:end_index,
-                            :,
-                            -(batch.max_decoder_input_length - 1):,
-                            :,
-                    ] = t[:, :, -(batch.max_decoder_input_length - 1):, :]
+                        start_index:end_index,
+                        :,
+                        -(batch.max_decoder_input_length - 1) :,
+                        :,
+                    ] = t[:, :, -(batch.max_decoder_input_length - 1) :, :]
 
                 # encoder past
                 for k, t in enumerate(past[2:]):
@@ -261,8 +261,8 @@ class Seq2SeqLMBatch(Batch):
                         past_key_values[j].append(t.new_zeros(padded_t_shape))
 
                     past_key_values[j][idx][
-                        start_index:end_index, :, -batch.max_input_length:, :
-                    ] = t[:, :, -batch.max_input_length:, :]
+                        start_index:end_index, :, -batch.max_input_length :, :
+                    ] = t[:, :, -batch.max_input_length :, :]
 
             start_index += batch.size
 
@@ -322,13 +322,13 @@ class Seq2SeqLM(Model):
         return self.tokenizer.decode(decoder_ids, skip_special_tokens=True)
 
     def forward(
-            self,
-            input_ids,
-            attention_mask,
-            decoder_input_ids,
-            decoder_attention_mask: Optional,
-            encoder_last_hidden_state: Optional,
-            past_key_values: Optional = None,
+        self,
+        input_ids,
+        attention_mask,
+        decoder_input_ids,
+        decoder_attention_mask: Optional,
+        encoder_last_hidden_state: Optional,
+        past_key_values: Optional = None,
     ) -> Tuple[
         torch.Tensor,
         torch.Tensor,
@@ -359,7 +359,7 @@ class Seq2SeqLM(Model):
         )
 
     def generate_token(
-            self, batch: Seq2SeqLMBatch
+        self, batch: Seq2SeqLMBatch
     ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
         # For some reason, inference_mode does not work well with GLOO which we use on CPU
         context_manager = (
@@ -405,14 +405,14 @@ class Seq2SeqLM(Model):
 
         # For each member of the batch
         for i, (
-                request,
-                input_length,
-                decoder_input_length,
-                logits,
-                next_token_chooser,
-                stopping_criteria,
-                input_tokens,
-                decoder_input_ids,
+            request,
+            input_length,
+            decoder_input_length,
+            logits,
+            next_token_chooser,
+            stopping_criteria,
+            input_tokens,
+            decoder_input_ids,
         ) in enumerate(iterator):
             # Select next token
             next_token_id, logprobs = next_token_chooser(decoder_input_ids, logits)
@@ -424,15 +424,14 @@ class Seq2SeqLM(Model):
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id]
             next_token_id_squeezed = next_token_id.squeeze()
-            next_token_text = self.tokenizer.decode(next_token_id_squeezed,
-                                                    clean_up_tokenization_spaces=False,
-                                                    skip_special_tokens=False)
+            next_token_text = self.tokenizer.decode(
+                next_token_id_squeezed,
+                clean_up_tokenization_spaces=False,
+                skip_special_tokens=False,
+            )
 
             # Evaluate stopping criteria
-            stop, reason = stopping_criteria(
-                next_token_id,
-                next_token_text
-            )
+            stop, reason = stopping_criteria(next_token_id, next_token_text)
 
             if stop:
                 # Slice with decoder_input_length to remove padding
diff --git a/server/text_generation/models/types.py b/server/text_generation/models/types.py
index 2407da4d..30cd716a 100644
--- a/server/text_generation/models/types.py
+++ b/server/text_generation/models/types.py
@@ -61,6 +61,9 @@ class PrefillTokens:
             ids=self.token_ids, logprobs=self.logprobs, texts=self.texts
         )
 
+    def __len__(self):
+        return len(self.token_ids)
+
 
 @dataclass
 class Generation:
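
Illustrative usage (not part of the patch): a minimal sketch of how the reworked generate_token return value is consumed, mirroring the assertions in the tests above. The fixture names (default_bloom, default_bloom_batch) are assumptions borrowed from the test suite.

# Sketch only: exercises the Generation / PrefillTokens surface introduced here.
generations, next_batch = default_bloom.generate_token(default_bloom_batch)

for generation in generations:
    # One Generation is returned per request in the batch on every decoding step.
    # These fixtures use single-token prompts, hence exactly one prefill token;
    # PrefillTokens now implements __len__, so len() works directly.
    assert len(generation.prefill_tokens) == 1
    print(generation.request_id, generation.token_id.item(), generation.token_text)
    # generated_text stays None until the request meets its stopping criteria;
    # once set, it carries the final text and the number of generated tokens.
    if generation.generated_text is not None:
        print(generation.generated_text.text, generation.generated_text.generated_tokens)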