Fixes and generation details arg

Joel Lamy-Poirier 2023-05-05 11:32:38 -04:00
parent b3b1b81982
commit a7c10f710f
3 changed files with 20 additions and 10 deletions


@@ -74,6 +74,7 @@ class Client:
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: bool = False,
details: bool = True,
) -> Response:
"""
Given a prompt, generate the following text
@@ -110,6 +111,8 @@ class Client:
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
watermark (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
details (`bool`):
Return the generation details
Returns:
Response: generated response
@@ -117,7 +120,7 @@ class Client:
# Validate parameters
parameters = Parameters(
best_of=best_of,
details=True,
details=details,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
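For reference, a minimal usage sketch of the new argument on the synchronous client (assuming this is the `text_generation` Python client; the endpoint URL and prompt are placeholders):

```python
from text_generation import Client

client = Client("http://127.0.0.1:8080")  # placeholder endpoint

# details defaults to True; passing False asks the server to skip the
# per-request generation details.
response = client.generate("def fibonacci(n):", max_new_tokens=32, details=False)
print(response.generated_text)
if response.details is not None:  # may be absent when details=False
    print(response.details)
```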
@@ -160,6 +163,7 @@ class Client:
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: bool = False,
details: bool = True,
) -> Iterator[StreamResponse]:
"""
Given a prompt, generate the following stream of tokens
@@ -194,6 +198,8 @@ class Client:
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
watermark (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
details (`bool`):
Return the generation details
Returns:
Iterator[StreamResponse]: stream of generated tokens
@@ -201,7 +207,7 @@ class Client:
# Validate parameters
parameters = Parameters(
best_of=None,
details=True,
details=details,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
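The streaming method takes the same flag; a short sketch under the same assumptions (token attribute names follow the upstream client and may differ here):

```python
from text_generation import Client

client = Client("http://127.0.0.1:8080")  # placeholder endpoint

# details is forwarded into Parameters exactly as in the hunk above.
text = ""
for stream_response in client.generate_stream(
    "def fibonacci(n):", max_new_tokens=32, details=True
):
    if not stream_response.token.special:
        text += stream_response.token.text
print(text)
```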
@@ -311,6 +317,7 @@ class AsyncClient:
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: bool = False,
details: bool = True,
) -> Response:
"""
Given a prompt, generate the following text asynchronously
@@ -347,6 +354,8 @@ class AsyncClient:
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
watermark (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
details (`bool`):
Return the generation details
Returns:
Response: generated response
@@ -354,7 +363,7 @@ class AsyncClient:
# Validate parameters
parameters = Parameters(
best_of=best_of,
details=True,
details=details,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
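The asynchronous client mirrors the same signature; a sketch with the usual asyncio boilerplate (endpoint and prompt are again placeholders):

```python
import asyncio

from text_generation import AsyncClient

async def main():
    client = AsyncClient("http://127.0.0.1:8080")  # placeholder endpoint
    # Same flag as on the sync client; details=False skips generation details.
    response = await client.generate("def fibonacci(n):", max_new_tokens=32, details=False)
    print(response.generated_text)

asyncio.run(main())
```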
@@ -395,6 +404,7 @@ class AsyncClient:
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: bool = False,
details: bool = True,
) -> AsyncIterator[StreamResponse]:
"""
Given a prompt, generate the following stream of tokens asynchronously
@@ -429,6 +439,8 @@ class AsyncClient:
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
watermark (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
details (`bool`):
Return the generation details
Returns:
AsyncIterator[StreamResponse]: stream of generated tokens
@@ -436,7 +448,7 @@ class AsyncClient:
# Validate parameters
parameters = Parameters(
best_of=None,
details=True,
details=details,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
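And the async streaming variant, for completeness (same assumptions as above):

```python
import asyncio

from text_generation import AsyncClient

async def stream():
    client = AsyncClient("http://127.0.0.1:8080")  # placeholder endpoint
    async for stream_response in client.generate_stream(
        "def fibonacci(n):", max_new_tokens=32, details=True
    ):
        if not stream_response.token.special:
            print(stream_response.token.text, end="", flush=True)

asyncio.run(stream())
```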


@@ -200,7 +200,7 @@ class Response(BaseModel):
# Generated text
generated_text: str
# Generation details
details: Details
details: Optional[Details]
# `generate_stream` details
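Since `details` can now be disabled per request, the response model has to tolerate a missing value; a minimal pydantic sketch of the relaxed field (the body of `Details` is illustrative, only the field names shown in the hunk are taken from it):

```python
from typing import Optional

from pydantic import BaseModel

class Details(BaseModel):
    # Illustrative stand-in for the real generation-details model.
    generated_tokens: int

class Response(BaseModel):
    # Generated text
    generated_text: str
    # Generation details; None when the request was sent with details=False
    details: Optional[Details]

# Both shapes now validate:
Response(generated_text="hello", details=Details(generated_tokens=3))
Response(generated_text="hello", details=None)
```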


@@ -367,9 +367,6 @@ class VectorizedNextTokenChooser:
return values
def __call__(self, input_ids, scores):
# Only process the last token
scores=scores[: -1, :]
if self.repetition_penalty_t is not None:
score = torch.gather(scores, 1, input_ids)
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
@@ -420,8 +417,6 @@ class VectorizedNextTokenChooser:
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
scores = scores.masked_fill(indices_to_remove, self.filter_value)
# Compute logprobs
logprobs = torch.log_softmax(scores, dim=-1)
if self.num_do_sample:
probs = torch.nn.functional.softmax(scores, -1)
@@ -431,6 +426,9 @@ class VectorizedNextTokenChooser:
else:
next_token_ids = torch.argmax(scores, dim=-1)
# Compute logprobs
logprobs = torch.log_softmax(scores, dim=-1).gather(1, next_token_ids.unsqueeze(1)).squeeze(1)
return next_token_ids, logprobs
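The last two hunks move the log-probability computation to after the token has been selected, gathering only the chosen token's logprob rather than keeping the full `[batch, vocab]` matrix; a standalone sketch of that pattern (greedy branch only, shapes illustrative):

```python
import torch

def choose_next_tokens(scores: torch.Tensor):
    # scores: [batch, vocab] logits after the warpers have been applied.
    next_token_ids = torch.argmax(scores, dim=-1)  # [batch]
    # Log-softmax over the vocabulary, then keep only the logprob of the
    # chosen token, as in the diff: the result has shape [batch].
    logprobs = (
        torch.log_softmax(scores, dim=-1)
        .gather(1, next_token_ids.unsqueeze(1))
        .squeeze(1)
    )
    return next_token_ids, logprobs

ids, lp = choose_next_tokens(torch.randn(2, 8))
print(ids.shape, lp.shape)  # torch.Size([2]) torch.Size([2])
```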
@classmethod