From a7c10f710f6e212e368d4baeaf03e059f36cf1d9 Mon Sep 17 00:00:00 2001
From: Joel Lamy-Poirier
Date: Fri, 5 May 2023 11:32:38 -0400
Subject: [PATCH] Fixes and generation details arg

---
 clients/python/text_generation/client.py          | 20 ++++++++++++++++----
 clients/python/text_generation/types.py           |  2 +-
 .../models/vectorized_causal_lm.py                |  8 +++-----
 3 files changed, 20 insertions(+), 10 deletions(-)

diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py
index 8b8742fc..bea283ca 100644
--- a/clients/python/text_generation/client.py
+++ b/clients/python/text_generation/client.py
@@ -74,6 +74,7 @@ class Client:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
+        details: bool = True,
     ) -> Response:
         """
         Given a prompt, generate the following text
@@ -110,6 +111,8 @@ class Client:
                 See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+            details (`bool`):
+                Return the generation details

         Returns:
             Response: generated response
@@ -117,7 +120,7 @@ class Client:
         # Validate parameters
         parameters = Parameters(
             best_of=best_of,
-            details=True,
+            details=details,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
             repetition_penalty=repetition_penalty,
@@ -160,6 +163,7 @@ class Client:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
+        details: bool = True,
     ) -> Iterator[StreamResponse]:
         """
         Given a prompt, generate the following stream of tokens
@@ -194,6 +198,8 @@ class Client:
                 See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+            details (`bool`):
+                Return the generation details

         Returns:
             Iterator[StreamResponse]: stream of generated tokens
@@ -201,7 +207,7 @@ class Client:
         # Validate parameters
         parameters = Parameters(
             best_of=None,
-            details=True,
+            details=details,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
             repetition_penalty=repetition_penalty,
@@ -311,6 +317,7 @@ class AsyncClient:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
+        details: bool = True,
     ) -> Response:
         """
         Given a prompt, generate the following text asynchronously
@@ -347,6 +354,8 @@ class AsyncClient:
                 See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+            details (`bool`):
+                Return the generation details

         Returns:
             Response: generated response
@@ -354,7 +363,7 @@ class AsyncClient:
         # Validate parameters
         parameters = Parameters(
             best_of=best_of,
-            details=True,
+            details=details,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
             repetition_penalty=repetition_penalty,
@@ -395,6 +404,7 @@ class AsyncClient:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
+        details: bool = True,
     ) -> AsyncIterator[StreamResponse]:
         """
         Given a prompt, generate the following stream of tokens asynchronously
@@ -429,6 +439,8 @@ class AsyncClient:
                 See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+            details (`bool`):
+                Return the generation details

         Returns:
             AsyncIterator[StreamResponse]: stream of generated tokens
@@ -436,7 +448,7 @@ class AsyncClient:
         # Validate parameters
         parameters = Parameters(
             best_of=None,
-            details=True,
+            details=details,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
             repetition_penalty=repetition_penalty,
diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py
index f3f9dcb5..96494305 100644
--- a/clients/python/text_generation/types.py
+++ b/clients/python/text_generation/types.py
@@ -200,7 +200,7 @@ class Response(BaseModel):
     # Generated text
     generated_text: str
     # Generation details
-    details: Details
+    details: Optional[Details]


 # `generate_stream` details
diff --git a/server/text_generation_server/models/vectorized_causal_lm.py b/server/text_generation_server/models/vectorized_causal_lm.py
index 2c8477a9..35d037a0 100644
--- a/server/text_generation_server/models/vectorized_causal_lm.py
+++ b/server/text_generation_server/models/vectorized_causal_lm.py
@@ -367,9 +367,6 @@ class VectorizedNextTokenChooser:
         return values

     def __call__(self, input_ids, scores):
-        # Only process the last token
-        scores=scores[: -1, :]
-
         if self.repetition_penalty_t is not None:
             score = torch.gather(scores, 1, input_ids)
             # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
@@ -420,8 +417,6 @@ class VectorizedNextTokenChooser:
             indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
             scores = scores.masked_fill(indices_to_remove, self.filter_value)

-        # Compute logprobs
-        logprobs = torch.log_softmax(scores, dim=-1)

         if self.num_do_sample:
             probs = torch.nn.functional.softmax(scores, -1)
@@ -431,6 +426,9 @@ class VectorizedNextTokenChooser:
         else:
             next_token_ids = torch.argmax(scores, dim=-1)

+        # Compute logprobs
+        logprobs = torch.log_softmax(scores, dim=-1).gather(1, next_token_ids.unsqueeze(1)).squeeze(1)
+
        return next_token_ids, logprobs

     @classmethod
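
On the client side, the new `details` argument can be exercised roughly as below. This is a minimal usage sketch, not part of the patch: the server URL, the prompt, and the assumption that a text-generation-inference server is already running are placeholders.

    from text_generation import Client

    client = Client("http://127.0.0.1:8080")

    # Default behaviour is unchanged: details are requested and populated.
    with_details = client.generate("What is Deep Learning?", max_new_tokens=20)
    print(with_details.details)

    # The new argument lets callers skip them; since Response.details is now
    # Optional[Details], it is expected to be None in that case.
    without_details = client.generate(
        "What is Deep Learning?", max_new_tokens=20, details=False
    )
    print(without_details.details)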
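
On the server side, the patch drops the `scores=scores[: -1, :]` slice (which removed the last row of `scores` rather than selecting the last token) and computes log-probabilities only for the token that was actually chosen, after sampling or argmax, instead of keeping the full vocabulary distribution. A small standalone illustration of that gather pattern, using toy tensors that are not from the patch:

    import torch

    # Toy scores: a batch of 2 sequences over a vocabulary of 5 tokens.
    scores = torch.randn(2, 5)

    # Greedy selection, as in the non-sampling branch of __call__.
    next_token_ids = torch.argmax(scores, dim=-1)

    # Log-probability of the selected token only: shape (2,) rather than (2, 5).
    logprobs = (
        torch.log_softmax(scores, dim=-1)
        .gather(1, next_token_ids.unsqueeze(1))
        .squeeze(1)
    )
    print(next_token_ids, logprobs)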