From cdc005bcb0624b4202d8c05db184f85438cbb851 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Fri, 2 Jun 2023 16:36:32 +0200
Subject: [PATCH] rename var

---
 Makefile                                      |  1 +
 clients/python/README.md                      | 16 +++++++--------
 clients/python/tests/test_client.py           | 16 ++++++++-------
 clients/python/text_generation/client.py      | 20 +++++++++----------
 clients/python/text_generation/types.py       | 16 +++++++--------
 integration-tests/conftest.py                 |  4 +++-
 integration-tests/models/test_bloom_560m.py   |  4 ++--
 .../models/test_bloom_560m_sharded.py         |  2 +-
 integration-tests/models/test_flash_falcon.py |  4 ++--
 integration-tests/models/test_flash_llama.py  |  4 ++--
 integration-tests/models/test_flash_neox.py   |  2 +-
 .../models/test_flash_neox_sharded.py         |  2 +-
 .../models/test_flash_santacoder.py           |  2 +-
 .../models/test_flash_starcoder.py            |  4 ++--
 integration-tests/models/test_mt0_base.py     |  4 ++--
 integration-tests/models/test_t5_sharded.py   |  2 +-
 integration-tests/requirements.txt            |  2 +-
 router/src/lib.rs                             |  4 ++--
 router/src/queue.rs                           |  3 ++-
 router/src/server.rs                          |  4 ++--
 router/src/validation.rs                      |  8 ++++----
 21 files changed, 65 insertions(+), 59 deletions(-)

diff --git a/Makefile b/Makefile
index 7309aaee..a33aba17 100644
--- a/Makefile
+++ b/Makefile
@@ -3,6 +3,7 @@ install-server:
 
 install-integration-tests:
 	cd integration-tests && pip install -r requirements.txt
+	cd clients/python && pip install .
 
 install-router:
 	cd router && cargo install --path .
diff --git a/clients/python/README.md b/clients/python/README.md
index 49a5182d..4e0e564c 100644
--- a/clients/python/README.md
+++ b/clients/python/README.md
@@ -138,11 +138,11 @@ class Parameters:
     best_of: Optional[int]
     # Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
     watermark: bool
-    # Get prompt token logprobs and ids
-    prefill_details: bool
+    # Get decoder input token logprobs and ids
+    decoder_input_details: bool
 
-# Prompt tokens
-class PrefillToken:
+# Decoder input tokens
+class InputToken:
     # Token ID from the model tokenizer
     id: int
     # Token text
@@ -185,8 +185,8 @@ class BestOfSequence:
     generated_tokens: int
     # Sampling seed if sampling was activated
     seed: Optional[int]
-    # Prompt tokens, empty if prefill_details is False
-    prefill: Optional[List[PrefillToken]]
+    # Decoder input tokens, empty if decoder_input_details is False
+    prefill: List[InputToken]
     # Generated tokens
     tokens: List[Token]
 
@@ -199,8 +199,8 @@ class Details:
     generated_tokens: int
     # Sampling seed if sampling was activated
     seed: Optional[int]
-    # Prompt tokens, empty if prefill_details is False
-    prefill: List[PrefillToken]
+    # Decoder input tokens, empty if decoder_input_details is False
+    prefill: List[InputToken]
     # Generated tokens
     tokens: List[Token]
     # Additional sequences when using the `best_of` parameter
diff --git a/clients/python/tests/test_client.py b/clients/python/tests/test_client.py
index 10f0a825..1e25e1b1 100644
--- a/clients/python/tests/test_client.py
+++ b/clients/python/tests/test_client.py
@@ -2,19 +2,19 @@ import pytest
 
 from text_generation import Client, AsyncClient
 from text_generation.errors import NotFoundError, ValidationError
-from text_generation.types import FinishReason, PrefillToken, Token
+from text_generation.types import FinishReason, InputToken
 
 
 def test_generate(flan_t5_xxl_url, hf_headers):
     client = Client(flan_t5_xxl_url, hf_headers)
-    response = client.generate("test", max_new_tokens=1, prefill_details=True)
+    response = client.generate("test", max_new_tokens=1, decoder_input_details=True)
 
     assert response.generated_text == ""
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
     assert len(response.details.prefill) == 1
-    assert response.details.prefill[0] == PrefillToken(id=0, text="", logprob=None)
+    assert response.details.prefill[0] == InputToken(id=0, text="", logprob=None)
     assert len(response.details.tokens) == 1
     assert response.details.tokens[0].id == 3
     assert response.details.tokens[0].text == " "
@@ -24,7 +24,7 @@ def test_generate(flan_t5_xxl_url, hf_headers):
 def test_generate_best_of(flan_t5_xxl_url, hf_headers):
     client = Client(flan_t5_xxl_url, hf_headers)
     response = client.generate(
-        "test", max_new_tokens=1, best_of=2, do_sample=True, prefill_details=True
+        "test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True
     )
 
     assert response.details.seed is not None
@@ -75,14 +75,16 @@ def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers):
 @pytest.mark.asyncio
 async def test_generate_async(flan_t5_xxl_url, hf_headers):
     client = AsyncClient(flan_t5_xxl_url, hf_headers)
-    response = await client.generate("test", max_new_tokens=1, prefill_details=True)
+    response = await client.generate(
+        "test", max_new_tokens=1, decoder_input_details=True
+    )
 
     assert response.generated_text == ""
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
     assert len(response.details.prefill) == 1
-    assert response.details.prefill[0] == PrefillToken(id=0, text="", logprob=None)
+    assert response.details.prefill[0] == InputToken(id=0, text="", logprob=None)
     assert len(response.details.tokens) == 1
     assert response.details.tokens[0].id == 3
     assert response.details.tokens[0].text == " "
@@ -93,7 +95,7 @@ async def test_generate_async(flan_t5_xxl_url, hf_headers):
 async def test_generate_async_best_of(flan_t5_xxl_url, hf_headers):
     client = AsyncClient(flan_t5_xxl_url, hf_headers)
     response = await client.generate(
-        "test", max_new_tokens=1, best_of=2, do_sample=True, prefill_details=True
+        "test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True
     )
 
     assert response.details.seed is not None
diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py
index be85d26b..bf045d47 100644
--- a/clients/python/text_generation/client.py
+++ b/clients/python/text_generation/client.py
@@ -74,7 +74,7 @@ class Client:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
-        prefill_details: bool = False,
+        decoder_input_details: bool = False,
     ) -> Response:
         """
         Given a prompt, generate the following text
@@ -111,8 +111,8 @@ class Client:
                 See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
-            prefill_details (`bool`):
-                Return the prefill token log probabilities
+            decoder_input_details (`bool`):
+                Return the decoder input token logprobs and ids
 
         Returns:
             Response: generated response
@@ -133,7 +133,7 @@ class Client:
             truncate=truncate,
             typical_p=typical_p,
             watermark=watermark,
-            prefill_details=prefill_details,
+            decoder_input_details=decoder_input_details,
         )
         request = Request(inputs=prompt, stream=False, parameters=parameters)
 
@@ -206,7 +206,7 @@ class Client:
         parameters = Parameters(
             best_of=None,
             details=True,
-            prefill_details=False,
+            decoder_input_details=False,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
             repetition_penalty=repetition_penalty,
@@ -316,7 +316,7 @@ class AsyncClient:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
-        prefill_details: bool = False,
+        decoder_input_details: bool = False,
     ) -> Response:
         """
         Given a prompt, generate the following text asynchronously
@@ -353,8 +353,8 @@ class AsyncClient:
                 See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
-            prefill_details (`bool`):
-                Return the prefill token log probabilities
+            decoder_input_details (`bool`):
+                Return the decoder input token logprobs and ids
 
         Returns:
             Response: generated response
@@ -363,7 +363,7 @@ class AsyncClient:
         parameters = Parameters(
             best_of=best_of,
             details=True,
-            prefill_details=prefill_details,
+            decoder_input_details=decoder_input_details,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
             repetition_penalty=repetition_penalty,
@@ -446,7 +446,7 @@ class AsyncClient:
         parameters = Parameters(
             best_of=None,
             details=True,
-            prefill_details=False,
+            decoder_input_details=False,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
             repetition_penalty=repetition_penalty,
diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py
index 2f78c033..548f0b63 100644
--- a/clients/python/text_generation/types.py
+++ b/clients/python/text_generation/types.py
@@ -37,8 +37,8 @@ class Parameters(BaseModel):
     watermark: bool = False
     # Get generation details
     details: bool = False
-    # Get prefill details
-    prefill_details: bool = False
+    # Get decoder input token logprobs and ids
+    decoder_input_details: bool = False
 
     @validator("best_of")
     def valid_best_of(cls, field_value, values):
@@ -131,8 +131,8 @@ class Request(BaseModel):
         return field_value
 
 
-# Prompt tokens
-class PrefillToken(BaseModel):
+# Decoder input tokens
+class InputToken(BaseModel):
     # Token ID from the model tokenizer
     id: int
     # Token text
@@ -175,8 +175,8 @@ class BestOfSequence(BaseModel):
     generated_tokens: int
     # Sampling seed if sampling was activated
     seed: Optional[int]
-    # Prompt tokens, empty if prefill_details is False
-    prefill: List[PrefillToken]
+    # Decoder input tokens, empty if decoder_input_details is False
+    prefill: List[InputToken]
     # Generated tokens
     tokens: List[Token]
 
@@ -189,8 +189,8 @@ class Details(BaseModel):
     generated_tokens: int
     # Sampling seed if sampling was activated
     seed: Optional[int]
-    # Prompt tokens, empty if prefill_details is False
-    prefill: List[PrefillToken]
+    # Decoder input tokens, empty if decoder_input_details is False
+    prefill: List[InputToken]
     # Generated tokens
     tokens: List[Token]
     # Additional sequences when using the `best_of` parameter
diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py
index 36d3eac3..c12e928f 100644
--- a/integration-tests/conftest.py
+++ b/integration-tests/conftest.py
@@ -332,7 +332,9 @@ def generate_load():
         client: AsyncClient, prompt: str, max_new_tokens: int, n: int
     ) -> List[Response]:
         futures = [
-            client.generate(prompt, max_new_tokens=max_new_tokens, prefill_details=True)
+            client.generate(
+                prompt, max_new_tokens=max_new_tokens, decoder_input_details=True
+            )
             for _ in range(n)
         ]
 
diff --git a/integration-tests/models/test_bloom_560m.py b/integration-tests/models/test_bloom_560m.py
index d7fa6cdb..bdcbdc78 100644
--- a/integration-tests/models/test_bloom_560m.py
+++ b/integration-tests/models/test_bloom_560m.py
@@ -19,7 +19,7 @@ async def test_bloom_560m(bloom_560, response_snapshot):
         "Pour déguster un ortolan, il faut tout d'abord",
         max_new_tokens=10,
         top_p=0.9,
-        prefill_details=True,
+        decoder_input_details=True,
         seed=0,
     )
 
@@ -41,7 +41,7 @@ async def test_bloom_560m_all_params(bloom_560, response_snapshot):
         truncate=5,
         typical_p=0.9,
         watermark=True,
-        prefill_details=True,
+        decoder_input_details=True,
         seed=0,
     )
 
diff --git a/integration-tests/models/test_bloom_560m_sharded.py b/integration-tests/models/test_bloom_560m_sharded.py
index 0781b8d9..3995f9e5 100644
--- a/integration-tests/models/test_bloom_560m_sharded.py
+++ b/integration-tests/models/test_bloom_560m_sharded.py
@@ -19,7 +19,7 @@ async def test_bloom_560m_sharded(bloom_560m_sharded, response_snapshot):
         "Pour déguster un ortolan, il faut tout d'abord",
         max_new_tokens=10,
         top_p=0.9,
-        prefill_details=True,
+        decoder_input_details=True,
         seed=0,
     )
 
diff --git a/integration-tests/models/test_flash_falcon.py b/integration-tests/models/test_flash_falcon.py
index ebb67cba..eac91984 100644
--- a/integration-tests/models/test_flash_falcon.py
+++ b/integration-tests/models/test_flash_falcon.py
@@ -19,7 +19,7 @@ async def test_flash_falcon(flash_falcon, response_snapshot):
     response = await flash_falcon.generate(
         "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:",
         max_new_tokens=10,
-        prefill_details=True,
+        decoder_input_details=True,
     )
 
     assert response.details.generated_tokens == 10
@@ -41,7 +41,7 @@ async def test_flash_falcon_all_params(flash_falcon, response_snapshot):
         truncate=5,
         typical_p=0.9,
         watermark=True,
-        prefill_details=True,
+        decoder_input_details=True,
         seed=0,
     )
 
diff --git a/integration-tests/models/test_flash_llama.py b/integration-tests/models/test_flash_llama.py
index fc16d172..c69314ff 100644
--- a/integration-tests/models/test_flash_llama.py
+++ b/integration-tests/models/test_flash_llama.py
@@ -17,7 +17,7 @@ async def flash_llama(flash_llama_handle):
 @pytest.mark.private
 async def test_flash_llama(flash_llama, response_snapshot):
     response = await flash_llama.generate(
-        "Test request", max_new_tokens=10, prefill_details=True
+        "Test request", max_new_tokens=10, decoder_input_details=True
     )
 
     assert response.details.generated_tokens == 10
@@ -39,7 +39,7 @@ async def test_flash_llama_all_params(flash_llama, response_snapshot):
         truncate=5,
         typical_p=0.9,
         watermark=True,
-        prefill_details=True,
+        decoder_input_details=True,
         seed=0,
     )
 
diff --git a/integration-tests/models/test_flash_neox.py b/integration-tests/models/test_flash_neox.py
index 2c165c3a..ff9b9763 100644
--- a/integration-tests/models/test_flash_neox.py
+++ b/integration-tests/models/test_flash_neox.py
@@ -18,7 +18,7 @@ async def test_flash_neox(flash_neox, response_snapshot):
     response = await flash_neox.generate(
         "<|USER|>What's your mood today?<|ASSISTANT|>",
         max_new_tokens=10,
-        prefill_details=True,
+        decoder_input_details=True,
     )
 
     assert response.details.generated_tokens == 10
diff --git a/integration-tests/models/test_flash_neox_sharded.py b/integration-tests/models/test_flash_neox_sharded.py
index b5106d11..8a491915 100644
--- a/integration-tests/models/test_flash_neox_sharded.py
+++ b/integration-tests/models/test_flash_neox_sharded.py
@@ -18,7 +18,7 @@ async def test_flash_neox(flash_neox_sharded, response_snapshot):
     response = await flash_neox_sharded.generate(
         "<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>",
         max_new_tokens=10,
-        prefill_details=True,
+        decoder_input_details=True,
     )
 
     assert response.details.generated_tokens == 10
diff --git a/integration-tests/models/test_flash_santacoder.py b/integration-tests/models/test_flash_santacoder.py
index 90146969..0f005f15 100644
--- a/integration-tests/models/test_flash_santacoder.py
+++ b/integration-tests/models/test_flash_santacoder.py
@@ -16,7 +16,7 @@ async def flash_santacoder(flash_santacoder_handle):
 @pytest.mark.asyncio
 async def test_flash_santacoder(flash_santacoder, response_snapshot):
     response = await flash_santacoder.generate(
-        "def print_hello", max_new_tokens=10, prefill_details=True
+        "def print_hello", max_new_tokens=10, decoder_input_details=True
     )
 
     assert response.details.generated_tokens == 10
diff --git a/integration-tests/models/test_flash_starcoder.py b/integration-tests/models/test_flash_starcoder.py
index c5910fcc..64e8b27c 100644
--- a/integration-tests/models/test_flash_starcoder.py
+++ b/integration-tests/models/test_flash_starcoder.py
@@ -17,7 +17,7 @@ async def flash_starcoder(flash_starcoder_handle):
 @pytest.mark.private
 async def test_flash_starcoder(flash_starcoder, response_snapshot):
     response = await flash_starcoder.generate(
-        "def print_hello", max_new_tokens=10, prefill_details=True
+        "def print_hello", max_new_tokens=10, decoder_input_details=True
     )
 
     assert response.details.generated_tokens == 10
@@ -32,7 +32,7 @@ async def test_flash_starcoder_default_params(flash_starcoder, response_snapshot):
         max_new_tokens=60,
         temperature=0.2,
         top_p=0.95,
-        prefill_details=True,
+        decoder_input_details=True,
         seed=0,
     )
 
diff --git a/integration-tests/models/test_mt0_base.py b/integration-tests/models/test_mt0_base.py
index 29de2909..12f23e4c 100644
--- a/integration-tests/models/test_mt0_base.py
+++ b/integration-tests/models/test_mt0_base.py
@@ -19,7 +19,7 @@ async def test_mt0_base(mt0_base, response_snapshot):
         "Why is the sky blue?",
         max_new_tokens=10,
         top_p=0.9,
-        prefill_details=True,
+        decoder_input_details=True,
         seed=0,
     )
 
@@ -41,7 +41,7 @@ async def test_mt0_base_all_params(mt0_base, response_snapshot):
         truncate=5,
         typical_p=0.9,
         watermark=True,
-        prefill_details=True,
+        decoder_input_details=True,
         seed=0,
     )
 
diff --git a/integration-tests/models/test_t5_sharded.py b/integration-tests/models/test_t5_sharded.py
index ff9b8eb7..7c288b23 100644
--- a/integration-tests/models/test_t5_sharded.py
+++ b/integration-tests/models/test_t5_sharded.py
@@ -18,7 +18,7 @@ async def test_t5_sharded(t5_sharded, response_snapshot):
     response = await t5_sharded.generate(
         "Please answer the following question. What is the boiling point of Nitrogen?",
         max_new_tokens=10,
-        prefill_details=True,
+        decoder_input_details=True,
     )
 
     assert response == response_snapshot
diff --git a/integration-tests/requirements.txt b/integration-tests/requirements.txt
index 051730ff..2f36d5d6 100644
--- a/integration-tests/requirements.txt
+++ b/integration-tests/requirements.txt
@@ -1,5 +1,5 @@
 syrupy
-text-generation==0.5.2
+text-generation
 pytest
 pytest-asyncio==0.17.2
 docker
\ No newline at end of file
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 4efe66ce..67fff017 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -126,7 +126,7 @@ pub(crate) struct GenerateParameters {
     pub details: bool,
     #[serde(default)]
     #[schema(default = "true")]
-    pub prefill_details: bool,
+    pub decoder_input_details: bool,
     #[serde(default)]
     #[schema(
         exclusive_minimum = 0,
@@ -156,7 +156,7 @@ fn default_parameters() -> GenerateParameters {
         truncate: None,
         watermark: false,
         details: false,
-        prefill_details: false,
+        decoder_input_details: false,
         seed: None,
     }
 }
diff --git a/router/src/queue.rs b/router/src/queue.rs
index b8470ebe..03807933 100644
--- a/router/src/queue.rs
+++ b/router/src/queue.rs
@@ -201,7 +201,7 @@ impl State {
 
             batch_requests.push(Request {
                 id,
-                prefill_logprobs: entry.request.prefill_details,
+                prefill_logprobs: entry.request.decoder_input_details,
                 inputs: entry.request.inputs.clone(),
                 truncate: entry.request.truncate,
                 parameters: Some(entry.request.parameters.clone()),
@@ -282,6 +282,7 @@ mod tests {
                 inputs: "".to_string(),
                 input_length: 0,
                 truncate: 0,
+                decoder_input_details: false,
                 parameters: NextTokenChooserParameters {
                     temperature: 0.0,
                     top_k: 0,
diff --git a/router/src/server.rs b/router/src/server.rs
index f0f205c5..10c0ba3c 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -160,7 +160,7 @@ async fn generate(
         add_prompt = Some(req.0.inputs.clone());
     }
 
-    let details = req.0.parameters.details;
+    let details = req.0.parameters.details || req.0.parameters.decoder_input_details;
 
     // Inference
     let (response, best_of_responses) = match req.0.parameters.best_of {
@@ -369,7 +369,7 @@ async fn generate_stream(
             metrics::increment_counter!("tgi_request_failure", "err" => "validation");
             tracing::error!("{err}");
             yield Ok(Event::from(err));
-        } else if req.0.parameters.prefill_details {
+        } else if req.0.parameters.decoder_input_details {
             let err = InferError::from(ValidationError::PrefillDetailsStream);
             metrics::increment_counter!("tgi_request_failure", "err" => "validation");
             tracing::error!("{err}");
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 63bf78a3..8843c6a8 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -145,7 +145,7 @@ impl Validation {
             truncate,
             seed,
             watermark,
-            prefill_details,
+            decoder_input_details,
             ..
         } = request.parameters;
 
@@ -262,7 +262,7 @@ impl Validation {
 
         Ok(ValidGenerateRequest {
             inputs,
-            prefill_details,
+            decoder_input_details,
            input_length: input_length as u32,
             truncate: truncate.unwrap_or(self.max_input_length) as u32,
             parameters,
@@ -337,7 +337,7 @@ pub(crate) struct ValidGenerateRequest {
     pub inputs: String,
     pub input_length: u32,
     pub truncate: u32,
-    pub prefill_details: bool,
+    pub decoder_input_details: bool,
     pub parameters: NextTokenChooserParameters,
     pub stopping_parameters: StoppingCriteriaParameters,
 }
@@ -354,7 +354,7 @@ pub enum ValidationError {
     BestOfSeed,
     #[error("`best_of` != 1 is not supported when streaming tokens")]
    BestOfStream,
-    #[error("`prefill_details` == true is not supported when streaming tokens")]
+    #[error("`decoder_input_details` == true is not supported when streaming tokens")]
     PrefillDetailsStream,
     #[error("`temperature` must be strictly positive")]
     Temperature,
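
A minimal usage sketch of the renamed client option follows; it is not part of the commit. It assumes a text-generation-inference server is already running at a hypothetical local endpoint (http://127.0.0.1:8080) and that the updated Python client from clients/python is installed. Client, generate(), decoder_input_details, and InputToken are the names introduced or kept by the diff above; the endpoint URL and prompt are illustrative only.

# Hedged sketch, not part of the patch. The endpoint URL is an assumption;
# the rest mirrors the renamed API from clients/python shown above.
from text_generation import Client
from text_generation.types import InputToken

client = Client("http://127.0.0.1:8080")  # assumed local TGI endpoint

# decoder_input_details replaces the former prefill_details flag
response = client.generate(
    "What is Deep Learning?",
    max_new_tokens=10,
    decoder_input_details=True,
)

# response.details.prefill is a List[InputToken] (formerly PrefillToken);
# each entry carries the id, text and logprob of a decoder input token
for token in response.details.prefill:
    assert isinstance(token, InputToken)
    print(token.id, token.text, token.logprob)

On the router side, the flag only applies to the non-streaming route: generate in router/src/server.rs now enables details whenever decoder_input_details is set, while generate_stream rejects decoder_input_details == true with the PrefillDetailsStream validation error.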