added logit_bias to python client

This commit is contained in:
marcusdunn 2023-08-15 13:35:19 -07:00
parent 247af2d1a8
commit a06b681673
2 changed files with 19 additions and 1 deletions

View File

@@ -75,6 +75,7 @@ class Client:
typical_p: Optional[float] = None,
watermark: bool = False,
decoder_input_details: bool = False,
logit_bias: Dict[str, float] = {},
) -> Response:
"""
Given a prompt, generate the following text
@@ -113,6 +114,8 @@ class Client:
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
decoder_input_details (`bool`):
Return the decoder input token logprobs and ids
logit_bias (`Dict[str, float]`):
Bias generation towards certain tokens.
Returns:
Response: generated response
@@ -134,6 +137,7 @@ class Client:
typical_p=typical_p,
watermark=watermark,
decoder_input_details=decoder_input_details,
logit_bias=logit_bias,
)
request = Request(inputs=prompt, stream=False, parameters=parameters)
@@ -164,6 +168,7 @@ class Client:
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: bool = False,
logit_bias: Dict[str, float] = {},
) -> Iterator[StreamResponse]:
"""
Given a prompt, generate the following stream of tokens
@@ -198,6 +203,8 @@ class Client:
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
watermark (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
logit_bias (`Dict[str, float]`):
Bias generation towards certain tokens.
Returns:
Iterator[StreamResponse]: stream of generated tokens
@@ -219,6 +226,7 @@ class Client:
truncate=truncate,
typical_p=typical_p,
watermark=watermark,
logit_bias=logit_bias,
)
request = Request(inputs=prompt, stream=True, parameters=parameters)
@@ -317,6 +325,7 @@ class AsyncClient:
typical_p: Optional[float] = None,
watermark: bool = False,
decoder_input_details: bool = False,
logit_bias: Dict[str, float] = {},
) -> Response:
"""
Given a prompt, generate the following text asynchronously
@@ -355,6 +364,8 @@ class AsyncClient:
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
decoder_input_details (`bool`):
Return the decoder input token logprobs and ids
logit_bias (`Dict[str, float]`):
Bias generation towards certain tokens.
Returns:
Response: generated response
@@ -376,6 +387,7 @@ class AsyncClient:
truncate=truncate,
typical_p=typical_p,
watermark=watermark,
logit_bias=logit_bias,
)
request = Request(inputs=prompt, stream=False, parameters=parameters)
@@ -404,6 +416,7 @@ class AsyncClient:
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: bool = False,
logit_bias: Dict[str, float] = {},
) -> AsyncIterator[StreamResponse]:
"""
Given a prompt, generate the following stream of tokens asynchronously
@@ -438,6 +451,8 @@ class AsyncClient:
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
watermark (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
logit_bias (`Dict[str, float]`):
Bias generation towards certain tokens.
Returns:
AsyncIterator[StreamResponse]: stream of generated tokens
@@ -459,6 +474,7 @@ class AsyncClient:
truncate=truncate,
typical_p=typical_p,
watermark=watermark,
logit_bias=logit_bias,
)
request = Request(inputs=prompt, stream=True, parameters=parameters)

View File

@@ -1,6 +1,6 @@
from enum import Enum
from pydantic import BaseModel, validator
from typing import Optional, List
from typing import Optional, List, Dict
from text_generation.errors import ValidationError
@@ -39,6 +39,8 @@ class Parameters(BaseModel):
details: bool = False
# Get decoder input token logprobs and ids
decoder_input_details: bool = False
# Bias generation towards certain tokens
logit_bias: Dict[str, float] = {}
@validator("best_of")
def valid_best_of(cls, field_value, values):