Fixes and generation details arg

Joel Lamy-Poirier 2023-05-05 11:32:38 -04:00
parent b3b1b81982
commit a7c10f710f
3 changed files with 20 additions and 10 deletions

File 1 of 3

@@ -74,6 +74,7 @@ class Client:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
+        details: bool = True,
     ) -> Response:
         """
         Given a prompt, generate the following text
@@ -110,6 +111,8 @@ class Client:
                 See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+            details (`bool`):
+                Return the generation details
         Returns:
             Response: generated response
@@ -117,7 +120,7 @@ class Client:
         # Validate parameters
         parameters = Parameters(
             best_of=best_of,
-            details=True,
+            details=details,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
             repetition_penalty=repetition_penalty,
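With this hunk, `details` is forwarded from the caller instead of being hard-coded to `True`. A minimal usage sketch of the new argument (the package name, endpoint, and `Details` field names follow the upstream text-generation client and are assumptions, not part of this diff):

```python
from text_generation import Client  # assumed package layout

client = Client("http://localhost:8080")  # placeholder endpoint

# Default behavior is unchanged: details=True
response = client.generate("What is deep learning?", max_new_tokens=20)
if response.details is not None:
    print(len(response.details.tokens))  # per-token generation details

# New: opt out of generation details entirely
response = client.generate("What is deep learning?", max_new_tokens=20, details=False)
print(response.generated_text)
```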
@@ -160,6 +163,7 @@ class Client:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
+        details: bool = True,
     ) -> Iterator[StreamResponse]:
         """
         Given a prompt, generate the following stream of tokens
@@ -194,6 +198,8 @@ class Client:
                 See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+            details (`bool`):
+                Return the generation details
         Returns:
             Iterator[StreamResponse]: stream of generated tokens
@@ -201,7 +207,7 @@ class Client:
         # Validate parameters
         parameters = Parameters(
             best_of=None,
-            details=True,
+            details=details,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
             repetition_penalty=repetition_penalty,
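The same flag applies to the streaming path. A short sketch, with the same assumptions as above (`StreamResponse.token` is assumed to match the upstream client's `Token` model):

```python
# Streaming with details disabled
for stream_response in client.generate_stream(
    "What is deep learning?", max_new_tokens=20, details=False
):
    if not stream_response.token.special:
        print(stream_response.token.text, end="")
```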
@@ -311,6 +317,7 @@ class AsyncClient:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
+        details: bool = True,
     ) -> Response:
         """
         Given a prompt, generate the following text asynchronously
@@ -347,6 +354,8 @@ class AsyncClient:
                 See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+            details (`bool`):
+                Return the generation details
         Returns:
             Response: generated response
@@ -354,7 +363,7 @@ class AsyncClient:
         # Validate parameters
         parameters = Parameters(
             best_of=best_of,
-            details=True,
+            details=details,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
             repetition_penalty=repetition_penalty,
@@ -395,6 +404,7 @@ class AsyncClient:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
+        details: bool = True,
     ) -> AsyncIterator[StreamResponse]:
         """
         Given a prompt, generate the following stream of tokens asynchronously
@@ -429,6 +439,8 @@ class AsyncClient:
                 See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+            details (`bool`):
+                Return the generation details
         Returns:
             AsyncIterator[StreamResponse]: stream of generated tokens
@@ -436,7 +448,7 @@ class AsyncClient:
         # Validate parameters
         parameters = Parameters(
             best_of=None,
-            details=True,
+            details=details,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
             repetition_penalty=repetition_penalty,
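The async client mirrors the sync API. A hedged sketch of both async entry points (constructor signature assumed to match the sync `Client`):

```python
import asyncio
from text_generation import AsyncClient  # assumed package layout

async def main():
    client = AsyncClient("http://localhost:8080")  # placeholder endpoint

    response = await client.generate("What is deep learning?", details=False)
    print(response.generated_text)

    async for stream_response in client.generate_stream(
        "What is deep learning?", details=False
    ):
        print(stream_response.token.text, end="")

asyncio.run(main())
```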

File 2 of 3

@@ -200,7 +200,7 @@ class Response(BaseModel):
     # Generated text
     generated_text: str
     # Generation details
-    details: Details
+    details: Optional[Details]
     # `generate_stream` details
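Since requests may now set `details=False`, the response model has to allow a missing `Details` payload. A minimal standalone sketch of the changed field (the `Details` body here is a stand-in, not the real model):

```python
from typing import Optional
from pydantic import BaseModel

class Details(BaseModel):
    # Stand-in for the real Details model
    generated_tokens: int = 0

class Response(BaseModel):
    # Generated text
    generated_text: str
    # Generation details; None when the request was made with details=False
    details: Optional[Details] = None
```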

File 3 of 3

@@ -367,9 +367,6 @@ class VectorizedNextTokenChooser:
         return values

     def __call__(self, input_ids, scores):
-        # Only process the last token
-        scores=scores[: -1, :]
         if self.repetition_penalty_t is not None:
             score = torch.gather(scores, 1, input_ids)
             # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
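The deleted slicing is one of the fixes in this commit: on the 2-D `[batch, vocab]` `scores` tensor used in `__call__`, `scores[:-1, :]` silently drops the last sequence in the batch rather than selecting a last token. A small demonstration (shapes are illustrative):

```python
import torch

scores = torch.randn(4, 32)      # [batch, vocab], as in __call__
print(scores[:-1, :].shape)      # torch.Size([3, 32]): one request lost

# Selecting the last position is only meaningful on [batch, seq, vocab] logits:
logits = torch.randn(4, 7, 32)
print(logits[:, -1, :].shape)    # torch.Size([4, 32])
```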
@@ -420,8 +417,6 @@ class VectorizedNextTokenChooser:
         indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
         scores = scores.masked_fill(indices_to_remove, self.filter_value)
-        # Compute logprobs
-        logprobs = torch.log_softmax(scores, dim=-1)
         if self.num_do_sample:
             probs = torch.nn.functional.softmax(scores, -1)
@@ -431,6 +426,9 @@ class VectorizedNextTokenChooser:
         else:
             next_token_ids = torch.argmax(scores, dim=-1)
+        # Compute logprobs
+        logprobs = torch.log_softmax(scores, dim=-1).gather(1, next_token_ids.unsqueeze(1)).squeeze(1)
         return next_token_ids, logprobs

     @classmethod
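The other fix: log-probabilities are now computed after token selection and gathered for just the chosen token, so `__call__` returns a `[batch]` vector instead of the full `[batch, vocab]` log-softmax matrix. A self-contained sketch of the new computation:

```python
import torch

scores = torch.randn(4, 32)                    # post-warper logits, [batch, vocab]
next_token_ids = torch.argmax(scores, dim=-1)  # greedy branch from the diff

# log_softmax over the vocab, then keep one logprob per sequence -> [batch]
logprobs = (
    torch.log_softmax(scores, dim=-1)
    .gather(1, next_token_ids.unsqueeze(1))
    .squeeze(1)
)
assert logprobs.shape == (4,)
```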