From a06b681673de3539c3382807d7f0969d7603fc8b Mon Sep 17 00:00:00 2001
From: marcusdunn
Date: Tue, 15 Aug 2023 13:35:19 -0700
Subject: [PATCH] added `logit_bias` to python client

---
 clients/python/text_generation/client.py | 16 ++++++++++++++++
 clients/python/text_generation/types.py  |  4 +++-
 2 files changed, 19 insertions(+), 1 deletion(-)

diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py
index bf045d47..6399548b 100644
--- a/clients/python/text_generation/client.py
+++ b/clients/python/text_generation/client.py
@@ -75,6 +75,7 @@ class Client:
         typical_p: Optional[float] = None,
         watermark: bool = False,
         decoder_input_details: bool = False,
+        logit_bias: Dict[str, float] = {},
     ) -> Response:
         """
         Given a prompt, generate the following text
@@ -113,6 +114,8 @@ class Client:
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
             decoder_input_details (`bool`):
                 Return the decoder input token logprobs and ids
+            logit_bias (`Dict[str, float]`):
+                Bias generation towards certain tokens.
 
         Returns:
             Response: generated response
@@ -134,6 +137,7 @@ class Client:
             typical_p=typical_p,
             watermark=watermark,
             decoder_input_details=decoder_input_details,
+            logit_bias=logit_bias,
         )
         request = Request(inputs=prompt, stream=False, parameters=parameters)
 
@@ -164,6 +168,7 @@ class Client:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
+        logit_bias: Dict[str, float] = {},
     ) -> Iterator[StreamResponse]:
         """
         Given a prompt, generate the following stream of tokens
@@ -198,6 +203,8 @@ class Client:
                 See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+            logit_bias (`Dict[str, float]`):
+                Bias generation towards certain tokens.
 
         Returns:
             Iterator[StreamResponse]: stream of generated tokens
@@ -219,6 +226,7 @@ class Client:
             truncate=truncate,
             typical_p=typical_p,
             watermark=watermark,
+            logit_bias=logit_bias,
         )
         request = Request(inputs=prompt, stream=True, parameters=parameters)
 
@@ -317,6 +325,7 @@ class AsyncClient:
         typical_p: Optional[float] = None,
         watermark: bool = False,
         decoder_input_details: bool = False,
+        logit_bias: Dict[str, float] = {},
     ) -> Response:
         """
         Given a prompt, generate the following text asynchronously
@@ -355,6 +364,8 @@ class AsyncClient:
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
             decoder_input_details (`bool`):
                 Return the decoder input token logprobs and ids
+            logit_bias (`Dict[str, float]`):
+                Bias generation towards certain tokens.
 
         Returns:
             Response: generated response
@@ -376,6 +387,7 @@ class AsyncClient:
             truncate=truncate,
             typical_p=typical_p,
             watermark=watermark,
+            logit_bias=logit_bias,
         )
         request = Request(inputs=prompt, stream=False, parameters=parameters)
 
@@ -404,6 +416,7 @@ class AsyncClient:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
+        logit_bias: Dict[str, float] = {},
     ) -> AsyncIterator[StreamResponse]:
         """
         Given a prompt, generate the following stream of tokens asynchronously
@@ -438,6 +451,8 @@ class AsyncClient:
                 See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+            logit_bias (`Dict[str, float]`):
+                Bias generation towards certain tokens.
 
         Returns:
             AsyncIterator[StreamResponse]: stream of generated tokens
@@ -459,6 +474,7 @@ class AsyncClient:
             truncate=truncate,
             typical_p=typical_p,
             watermark=watermark,
+            logit_bias=logit_bias,
         )
         request = Request(inputs=prompt, stream=True, parameters=parameters)
 
diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py
index 548f0b63..62e79a0b 100644
--- a/clients/python/text_generation/types.py
+++ b/clients/python/text_generation/types.py
@@ -1,6 +1,6 @@
 from enum import Enum
 from pydantic import BaseModel, validator
-from typing import Optional, List
+from typing import Optional, List, Dict
 
 from text_generation.errors import ValidationError
 
@@ -39,6 +39,8 @@ class Parameters(BaseModel):
     details: bool = False
     # Get decoder input token logprobs and ids
     decoder_input_details: bool = False
+    # Bias generation towards certain tokens
+    logit_bias: Dict[str, float] = {}
 
     @validator("best_of")
     def valid_best_of(cls, field_value, values):