rename var

OlivierDehaene 2023-06-02 16:36:32 +02:00
parent 62ff3816fb
commit cdc005bcb0
21 changed files with 65 additions and 59 deletions

View File

@@ -3,6 +3,7 @@ install-server:
 install-integration-tests:
 	cd integration-tests && pip install -r requirements.txt
+	cd clients/python && pip install .
 install-router:
 	cd router && cargo install --path .

View File

@@ -138,11 +138,11 @@ class Parameters:
     best_of: Optional[int]
     # Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
     watermark: bool
-    # Get prompt token logprobs and ids
-    prefill_details: bool
+    # Get decoder input token logprobs and ids
+    decoder_input_details: bool

-# Prompt tokens
-class PrefillToken:
+# Decoder input tokens
+class InputToken:
     # Token ID from the model tokenizer
     id: int
     # Token text
@@ -185,8 +185,8 @@ class BestOfSequence:
     generated_tokens: int
     # Sampling seed if sampling was activated
     seed: Optional[int]
-    # Prompt tokens, empty if prefill_details is False
-    prefill: Optional[List[PrefillToken]]
+    # Decoder input tokens, empty if decoder_input_details is False
+    prefill: List[InputToken]
     # Generated tokens
     tokens: List[Token]
@@ -199,8 +199,8 @@ class Details:
     generated_tokens: int
     # Sampling seed if sampling was activated
     seed: Optional[int]
-    # Prompt tokens, empty if prefill_details is False
-    prefill: List[PrefillToken]
+    # Decoder input tokens, empty if decoder_input_details is False
+    prefill: List[InputToken]
     # Generated tokens
     tokens: List[Token]
     # Additional sequences when using the `best_of` parameter
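The renamed field is easiest to see on a response object. A minimal sketch using the synchronous client, assuming a running text-generation-inference server; the URL, prompt, and printed values are placeholders for illustration:

```python
from text_generation import Client

# Placeholder URL; point this at a running text-generation-inference server.
client = Client("http://127.0.0.1:8080")

# With decoder_input_details=True, `details.prefill` holds one InputToken
# (id, text, logprob) per decoder input token; with False it stays empty.
response = client.generate("test", max_new_tokens=1, decoder_input_details=True)
for token in response.details.prefill:
    print(token.id, token.text, token.logprob)
```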

View File

@@ -2,19 +2,19 @@ import pytest
 from text_generation import Client, AsyncClient
 from text_generation.errors import NotFoundError, ValidationError
-from text_generation.types import FinishReason, PrefillToken, Token
+from text_generation.types import FinishReason, InputToken
 def test_generate(flan_t5_xxl_url, hf_headers):
     client = Client(flan_t5_xxl_url, hf_headers)
-    response = client.generate("test", max_new_tokens=1, prefill_details=True)
+    response = client.generate("test", max_new_tokens=1, decoder_input_details=True)
     assert response.generated_text == ""
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
     assert len(response.details.prefill) == 1
-    assert response.details.prefill[0] == PrefillToken(id=0, text="<pad>", logprob=None)
+    assert response.details.prefill[0] == InputToken(id=0, text="<pad>", logprob=None)
     assert len(response.details.tokens) == 1
     assert response.details.tokens[0].id == 3
     assert response.details.tokens[0].text == " "
@@ -24,7 +24,7 @@ def test_generate(flan_t5_xxl_url, hf_headers):
 def test_generate_best_of(flan_t5_xxl_url, hf_headers):
     client = Client(flan_t5_xxl_url, hf_headers)
     response = client.generate(
-        "test", max_new_tokens=1, best_of=2, do_sample=True, prefill_details=True
+        "test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True
     )
     assert response.details.seed is not None
@@ -75,14 +75,16 @@ def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers):
 @pytest.mark.asyncio
 async def test_generate_async(flan_t5_xxl_url, hf_headers):
     client = AsyncClient(flan_t5_xxl_url, hf_headers)
-    response = await client.generate("test", max_new_tokens=1, prefill_details=True)
+    response = await client.generate(
+        "test", max_new_tokens=1, decoder_input_details=True
+    )
     assert response.generated_text == ""
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
     assert len(response.details.prefill) == 1
-    assert response.details.prefill[0] == PrefillToken(id=0, text="<pad>", logprob=None)
+    assert response.details.prefill[0] == InputToken(id=0, text="<pad>", logprob=None)
     assert len(response.details.tokens) == 1
     assert response.details.tokens[0].id == 3
     assert response.details.tokens[0].text == " "
@@ -93,7 +95,7 @@ async def test_generate_async(flan_t5_xxl_url, hf_headers):
 async def test_generate_async_best_of(flan_t5_xxl_url, hf_headers):
     client = AsyncClient(flan_t5_xxl_url, hf_headers)
     response = await client.generate(
-        "test", max_new_tokens=1, best_of=2, do_sample=True, prefill_details=True
+        "test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True
     )
     assert response.details.seed is not None

View File

@@ -74,7 +74,7 @@ class Client:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
-        prefill_details: bool = False,
+        decoder_input_details: bool = False,
     ) -> Response:
         """
         Given a prompt, generate the following text
@@ -111,8 +111,8 @@ class Client:
                 See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
-            prefill_details (`bool`):
-                Return the prefill token log probabilities
+            decoder_input_details (`bool`):
+                Return the decoder input token logprobs and ids

         Returns:
             Response: generated response
@@ -133,7 +133,7 @@ class Client:
             truncate=truncate,
             typical_p=typical_p,
             watermark=watermark,
-            prefill_details=prefill_details,
+            decoder_input_details=decoder_input_details,
         )
         request = Request(inputs=prompt, stream=False, parameters=parameters)
@@ -206,7 +206,7 @@ class Client:
         parameters = Parameters(
             best_of=None,
             details=True,
-            prefill_details=False,
+            decoder_input_details=False,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
             repetition_penalty=repetition_penalty,
@@ -316,7 +316,7 @@ class AsyncClient:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
-        prefill_details: bool = False,
+        decoder_input_details: bool = False,
     ) -> Response:
         """
         Given a prompt, generate the following text asynchronously
@@ -353,8 +353,8 @@ class AsyncClient:
                 See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
-            prefill_details (`bool`):
-                Return the prefill token log probabilities
+            decoder_input_details (`bool`):
+                Return the decoder input token logprobs and ids

         Returns:
             Response: generated response
@@ -363,7 +363,7 @@ class AsyncClient:
         parameters = Parameters(
             best_of=best_of,
             details=True,
-            prefill_details=prefill_details,
+            decoder_input_details=decoder_input_details,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
             repetition_penalty=repetition_penalty,
@@ -446,7 +446,7 @@ class AsyncClient:
         parameters = Parameters(
             best_of=None,
             details=True,
-            prefill_details=False,
+            decoder_input_details=False,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
             repetition_penalty=repetition_penalty,
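The async client takes the same renamed keyword argument. A small sketch, with the server URL as a placeholder:

```python
import asyncio

from text_generation import AsyncClient


async def main() -> None:
    # Placeholder URL; point this at a running text-generation-inference server.
    client = AsyncClient("http://127.0.0.1:8080")
    response = await client.generate(
        "test", max_new_tokens=1, decoder_input_details=True
    )
    # Decoder input tokens are reported alongside the generated ones.
    print(response.details.prefill)
    print(response.details.tokens)


asyncio.run(main())
```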

View File

@@ -37,8 +37,8 @@ class Parameters(BaseModel):
     watermark: bool = False
     # Get generation details
     details: bool = False
-    # Get prefill details
-    prefill_details: bool = False
+    # Get decoder input token logprobs and ids
+    decoder_input_details: bool = False

     @validator("best_of")
     def valid_best_of(cls, field_value, values):
@@ -131,8 +131,8 @@ class Request(BaseModel):
         return field_value

-# Prompt tokens
-class PrefillToken(BaseModel):
+# Decoder input tokens
+class InputToken(BaseModel):
     # Token ID from the model tokenizer
     id: int
     # Token text
@@ -175,8 +175,8 @@ class BestOfSequence(BaseModel):
     generated_tokens: int
     # Sampling seed if sampling was activated
     seed: Optional[int]
-    # Prompt tokens, empty if prefill_details is False
-    prefill: List[PrefillToken]
+    # Decoder input tokens, empty if decoder_input_details is False
+    prefill: List[InputToken]
     # Generated tokens
     tokens: List[Token]
@@ -189,8 +189,8 @@ class Details(BaseModel):
     generated_tokens: int
     # Sampling seed if sampling was activated
     seed: Optional[int]
-    # Prompt tokens, empty if prefill_details is False
-    prefill: List[PrefillToken]
+    # Decoder input tokens, empty if decoder_input_details is False
+    prefill: List[InputToken]
     # Generated tokens
     tokens: List[Token]
     # Additional sequences when using the `best_of` parameter
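The renamed pydantic models can also be constructed directly, which is how the client tests above compare prefill tokens. A hedged sketch based on those test expectations (it assumes the other `Parameters` fields keep their defaults):

```python
from text_generation.types import InputToken, Parameters

# Request-side flag after the rename (defaults to False).
params = Parameters(details=True, decoder_input_details=True)

# Response-side token as asserted in the tests: a <pad> token with no logprob.
pad = InputToken(id=0, text="<pad>", logprob=None)
assert pad.logprob is None
```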

View File

@@ -332,7 +332,9 @@ def generate_load():
         client: AsyncClient, prompt: str, max_new_tokens: int, n: int
     ) -> List[Response]:
         futures = [
-            client.generate(prompt, max_new_tokens=max_new_tokens, prefill_details=True)
+            client.generate(
+                prompt, max_new_tokens=max_new_tokens, decoder_input_details=True
+            )
             for _ in range(n)
         ]

View File

@@ -19,7 +19,7 @@ async def test_bloom_560m(bloom_560, response_snapshot):
         "Pour déguster un ortolan, il faut tout d'abord",
         max_new_tokens=10,
         top_p=0.9,
-        prefill_details=True,
+        decoder_input_details=True,
         seed=0,
     )
@@ -41,7 +41,7 @@ async def test_bloom_560m_all_params(bloom_560, response_snapshot):
         truncate=5,
         typical_p=0.9,
         watermark=True,
-        prefill_details=True,
+        decoder_input_details=True,
         seed=0,
     )

View File

@@ -19,7 +19,7 @@ async def test_bloom_560m_sharded(bloom_560m_sharded, response_snapshot):
         "Pour déguster un ortolan, il faut tout d'abord",
         max_new_tokens=10,
         top_p=0.9,
-        prefill_details=True,
+        decoder_input_details=True,
         seed=0,
     )

View File

@@ -19,7 +19,7 @@ async def test_flash_falcon(flash_falcon, response_snapshot):
     response = await flash_falcon.generate(
         "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:",
         max_new_tokens=10,
-        prefill_details=True,
+        decoder_input_details=True,
     )

     assert response.details.generated_tokens == 10
@@ -41,7 +41,7 @@ async def test_flash_falcon_all_params(flash_falcon, response_snapshot):
         truncate=5,
         typical_p=0.9,
         watermark=True,
-        prefill_details=True,
+        decoder_input_details=True,
         seed=0,
     )

View File

@@ -17,7 +17,7 @@ async def flash_llama(flash_llama_handle):
 @pytest.mark.private
 async def test_flash_llama(flash_llama, response_snapshot):
     response = await flash_llama.generate(
-        "Test request", max_new_tokens=10, prefill_details=True
+        "Test request", max_new_tokens=10, decoder_input_details=True
     )

     assert response.details.generated_tokens == 10
@@ -39,7 +39,7 @@ async def test_flash_llama_all_params(flash_llama, response_snapshot):
         truncate=5,
         typical_p=0.9,
         watermark=True,
-        prefill_details=True,
+        decoder_input_details=True,
         seed=0,
     )

View File

@@ -18,7 +18,7 @@ async def test_flash_neox(flash_neox, response_snapshot):
     response = await flash_neox.generate(
         "<|USER|>What's your mood today?<|ASSISTANT|>",
         max_new_tokens=10,
-        prefill_details=True,
+        decoder_input_details=True,
     )

     assert response.details.generated_tokens == 10

View File

@@ -18,7 +18,7 @@ async def test_flash_neox(flash_neox_sharded, response_snapshot):
     response = await flash_neox_sharded.generate(
         "<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>",
         max_new_tokens=10,
-        prefill_details=True,
+        decoder_input_details=True,
     )

     assert response.details.generated_tokens == 10

View File

@@ -16,7 +16,7 @@ async def flash_santacoder(flash_santacoder_handle):
 @pytest.mark.asyncio
 async def test_flash_santacoder(flash_santacoder, response_snapshot):
     response = await flash_santacoder.generate(
-        "def print_hello", max_new_tokens=10, prefill_details=True
+        "def print_hello", max_new_tokens=10, decoder_input_details=True
     )

     assert response.details.generated_tokens == 10

View File

@@ -17,7 +17,7 @@ async def flash_starcoder(flash_starcoder_handle):
 @pytest.mark.private
 async def test_flash_starcoder(flash_starcoder, response_snapshot):
     response = await flash_starcoder.generate(
-        "def print_hello", max_new_tokens=10, prefill_details=True
+        "def print_hello", max_new_tokens=10, decoder_input_details=True
     )

     assert response.details.generated_tokens == 10
@@ -32,7 +32,7 @@ async def test_flash_starcoder_default_params(flash_starcoder, response_snapshot):
         max_new_tokens=60,
         temperature=0.2,
         top_p=0.95,
-        prefill_details=True,
+        decoder_input_details=True,
         seed=0,
     )

View File

@@ -19,7 +19,7 @@ async def test_mt0_base(mt0_base, response_snapshot):
         "Why is the sky blue?",
         max_new_tokens=10,
         top_p=0.9,
-        prefill_details=True,
+        decoder_input_details=True,
         seed=0,
     )
@@ -41,7 +41,7 @@ async def test_mt0_base_all_params(mt0_base, response_snapshot):
         truncate=5,
         typical_p=0.9,
         watermark=True,
-        prefill_details=True,
+        decoder_input_details=True,
         seed=0,
     )

View File

@@ -18,7 +18,7 @@ async def test_t5_sharded(t5_sharded, response_snapshot):
     response = await t5_sharded.generate(
         "Please answer the following question. What is the boiling point of Nitrogen?",
         max_new_tokens=10,
-        prefill_details=True,
+        decoder_input_details=True,
     )

     assert response == response_snapshot

View File

@@ -1,5 +1,5 @@
 syrupy
-text-generation==0.5.2
+text-generation
 pytest
 pytest-asyncio==0.17.2
 docker

View File

@@ -126,7 +126,7 @@ pub(crate) struct GenerateParameters {
     pub details: bool,
     #[serde(default)]
     #[schema(default = "true")]
-    pub prefill_details: bool,
+    pub decoder_input_details: bool,
     #[serde(default)]
     #[schema(
         exclusive_minimum = 0,
@@ -156,7 +156,7 @@ fn default_parameters() -> GenerateParameters {
         truncate: None,
         watermark: false,
         details: false,
-        prefill_details: false,
+        decoder_input_details: false,
         seed: None,
     }
 }
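At the HTTP level the rename appears as a new JSON field under `parameters`. A minimal sketch of such a request with Python's `requests`; the host, port, and response handling are assumptions for illustration:

```python
import requests

# Placeholder endpoint; point this at a running text-generation-inference router.
resp = requests.post(
    "http://127.0.0.1:8080/generate",
    json={
        "inputs": "def print_hello",
        "parameters": {
            "max_new_tokens": 10,
            "decoder_input_details": True,  # formerly `prefill_details`
        },
    },
    timeout=60,
)
resp.raise_for_status()
# Per the server change below, requesting decoder input details also implies
# `details`, so the response should carry a `details.prefill` section.
print(resp.json()["details"]["prefill"])
```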

View File

@@ -201,7 +201,7 @@ impl State {
             batch_requests.push(Request {
                 id,
-                prefill_logprobs: entry.request.prefill_details,
+                prefill_logprobs: entry.request.decoder_input_details,
                 inputs: entry.request.inputs.clone(),
                 truncate: entry.request.truncate,
                 parameters: Some(entry.request.parameters.clone()),
@@ -282,6 +282,7 @@ mod tests {
             inputs: "".to_string(),
             input_length: 0,
             truncate: 0,
+            decoder_input_details: false,
             parameters: NextTokenChooserParameters {
                 temperature: 0.0,
                 top_k: 0,

View File

@@ -160,7 +160,7 @@ async fn generate(
         add_prompt = Some(req.0.inputs.clone());
     }

-    let details = req.0.parameters.details;
+    let details = req.0.parameters.details || req.0.parameters.decoder_input_details;

     // Inference
     let (response, best_of_responses) = match req.0.parameters.best_of {
@@ -369,7 +369,7 @@ async fn generate_stream(
                         metrics::increment_counter!("tgi_request_failure", "err" => "validation");
                         tracing::error!("{err}");
                         yield Ok(Event::from(err));
-                    } else if req.0.parameters.prefill_details {
+                    } else if req.0.parameters.decoder_input_details {
                         let err = InferError::from(ValidationError::PrefillDetailsStream);
                         metrics::increment_counter!("tgi_request_failure", "err" => "validation");
                         tracing::error!("{err}");
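Because the streaming route now rejects the renamed flag, a raw streaming request with `decoder_input_details` set should surface the `PrefillDetailsStream` validation error. A hedged sketch; the endpoint path and error text are taken from this diff, the URL is a placeholder, and the exact event format is an assumption:

```python
import requests

# Placeholder endpoint; point this at a running text-generation-inference router.
resp = requests.post(
    "http://127.0.0.1:8080/generate_stream",
    json={
        "inputs": "test",
        "parameters": {"max_new_tokens": 1, "decoder_input_details": True},
    },
    stream=True,
    timeout=60,
)
for line in resp.iter_lines():
    # Expected to include an error event along the lines of:
    # "`decoder_input_details` == true is not supported when streaming tokens"
    print(line.decode("utf-8"))
```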

View File

@@ -145,7 +145,7 @@ impl Validation {
             truncate,
             seed,
             watermark,
-            prefill_details,
+            decoder_input_details,
             ..
         } = request.parameters;
@@ -262,7 +262,7 @@ impl Validation {
         Ok(ValidGenerateRequest {
             inputs,
-            prefill_details,
+            decoder_input_details,
             input_length: input_length as u32,
             truncate: truncate.unwrap_or(self.max_input_length) as u32,
             parameters,
@@ -337,7 +337,7 @@ pub(crate) struct ValidGenerateRequest {
     pub inputs: String,
     pub input_length: u32,
     pub truncate: u32,
-    pub prefill_details: bool,
+    pub decoder_input_details: bool,
     pub parameters: NextTokenChooserParameters,
     pub stopping_parameters: StoppingCriteriaParameters,
 }
@@ -354,7 +354,7 @@ pub enum ValidationError {
     BestOfSeed,
     #[error("`best_of` != 1 is not supported when streaming tokens")]
     BestOfStream,
-    #[error("`prefill_details` == true is not supported when streaming tokens")]
+    #[error("`decoder_input_details` == true is not supported when streaming tokens")]
     PrefillDetailsStream,
     #[error("`temperature` must be strictly positive")]
     Temperature,