added logit_bias to python client

author marcusdunn 2023-08-15 13:35:19 -07:00
parent 247af2d1a8
commit a06b681673

2 changed files with 19 additions and 1 deletion


@@ -75,6 +75,7 @@ class Client:
         typical_p: Optional[float] = None,
         watermark: bool = False,
         decoder_input_details: bool = False,
+        logit_bias: Dict[str, float] = {},
     ) -> Response:
         """
         Given a prompt, generate the following text
@@ -113,6 +114,8 @@ class Client:
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
             decoder_input_details (`bool`):
                 Return the decoder input token logprobs and ids
+            logit_bias (`Dict[str, float]`):
+                Bias generation towards certain tokens.

         Returns:
             Response: generated response
@@ -134,6 +137,7 @@ class Client:
             typical_p=typical_p,
             watermark=watermark,
             decoder_input_details=decoder_input_details,
+            logit_bias=logit_bias,
         )
         request = Request(inputs=prompt, stream=False, parameters=parameters)

@@ -164,6 +168,7 @@ class Client:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
+        logit_bias: Dict[str, float] = {},
     ) -> Iterator[StreamResponse]:
         """
         Given a prompt, generate the following stream of tokens
@@ -198,6 +203,8 @@ class Client:
                 See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+            logit_bias (`Dict[str, float]`):
+                Bias generation towards certain tokens.

         Returns:
             Iterator[StreamResponse]: stream of generated tokens
@@ -219,6 +226,7 @@ class Client:
             truncate=truncate,
             typical_p=typical_p,
             watermark=watermark,
+            logit_bias=logit_bias,
         )
         request = Request(inputs=prompt, stream=True, parameters=parameters)

@@ -317,6 +325,7 @@ class AsyncClient:
         typical_p: Optional[float] = None,
         watermark: bool = False,
         decoder_input_details: bool = False,
+        logit_bias: Dict[str, float] = {},
     ) -> Response:
         """
         Given a prompt, generate the following text asynchronously
@@ -355,6 +364,8 @@ class AsyncClient:
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
             decoder_input_details (`bool`):
                 Return the decoder input token logprobs and ids
+            logit_bias (`Dict[str, float]`):
+                Bias generation towards certain tokens.

         Returns:
             Response: generated response
@@ -376,6 +387,7 @@ class AsyncClient:
             truncate=truncate,
             typical_p=typical_p,
             watermark=watermark,
+            logit_bias=logit_bias,
         )
         request = Request(inputs=prompt, stream=False, parameters=parameters)

@@ -404,6 +416,7 @@ class AsyncClient:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
+        logit_bias: Dict[str, float] = {},
     ) -> AsyncIterator[StreamResponse]:
         """
         Given a prompt, generate the following stream of tokens asynchronously
@@ -438,6 +451,8 @@ class AsyncClient:
                 See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+            logit_bias (`Dict[str, float]`):
+                Bias generation towards certain tokens.

         Returns:
             AsyncIterator[StreamResponse]: stream of generated tokens
@@ -459,6 +474,7 @@ class AsyncClient:
             truncate=truncate,
             typical_p=typical_p,
             watermark=watermark,
+            logit_bias=logit_bias,
         )
         request = Request(inputs=prompt, stream=True, parameters=parameters)

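For context, here is a minimal usage sketch of the new parameter. It is not part of the commit: the endpoint URL and token IDs are hypothetical, and the import assumes the package's top-level Client export.

    # Hypothetical usage sketch of the new logit_bias parameter (not from this commit).
    from text_generation import Client

    client = Client("http://localhost:8080")  # hypothetical local endpoint

    # Keys are token identifiers as strings (per Dict[str, float]); values are
    # additive biases. The IDs below are made up for illustration.
    response = client.generate(
        "The capital of France is",
        logit_bias={"1234": -100.0, "5678": 5.0},
    )
    print(response.generated_text)

One design note: a bare {} as a Python function default is shared across calls; that is harmless as long as the client only reads the dict, but Optional[Dict[str, float]] = None is the more defensive idiom.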


@@ -1,6 +1,6 @@
 from enum import Enum
 from pydantic import BaseModel, validator
-from typing import Optional, List
+from typing import Optional, List, Dict

 from text_generation.errors import ValidationError

@@ -39,6 +39,8 @@ class Parameters(BaseModel):
     details: bool = False
     # Get decoder input token logprobs and ids
     decoder_input_details: bool = False
+    # Bias generation towards certain tokens
+    logit_bias: Dict[str, float] = {}

     @validator("best_of")
     def valid_best_of(cls, field_value, values):
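On the model side, the new field serializes along with the rest of Parameters. A small sketch, assuming the module path text_generation.types and pydantic v1 (the validator import above suggests v1); the token ID and bias value are hypothetical:

    # Sketch: logit_bias rides along in the serialized Parameters payload.
    from text_generation.types import Parameters

    params = Parameters(logit_bias={"42": -10.0})  # hypothetical token ID and bias
    print(params.dict())  # the dump includes "logit_bias": {"42": -10.0}

Because this is a pydantic model field, the {} default is copied per instance rather than shared, so the mutable-default concern noted above does not apply here.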