Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)
Fixes and generation details arg

parent: b3b1b81982
commit: a7c10f710f
@@ -74,6 +74,7 @@ class Client:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
+        details: bool = True,
     ) -> Response:
         """
         Given a prompt, generate the following text
@@ -110,6 +111,8 @@ class Client:
                 See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+            details (`bool`):
+                Return the generation details
 
         Returns:
             Response: generated response
@@ -117,7 +120,7 @@ class Client:
         # Validate parameters
         parameters = Parameters(
             best_of=best_of,
-            details=True,
+            details=details,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
             repetition_penalty=repetition_penalty,
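Taken together, the hunks above add a `details` keyword to `Client.generate` and forward it to `Parameters` instead of hard-coding `True`. A minimal usage sketch, assuming the `text_generation` Python client and a server running at a placeholder URL:

    from text_generation import Client

    client = Client("http://127.0.0.1:8080")

    # Default keeps the previous behaviour: generation details are requested and populated
    response = client.generate("Why is the sky blue?", max_new_tokens=20)
    print(response.generated_text)
    print(response.details.generated_tokens)

    # Opting out of details; with this change `response.details` can be None
    response = client.generate("Why is the sky blue?", max_new_tokens=20, details=False)
    assert response.details is None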
@@ -160,6 +163,7 @@ class Client:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
+        details: bool = True,
     ) -> Iterator[StreamResponse]:
         """
         Given a prompt, generate the following stream of tokens
@@ -194,6 +198,8 @@ class Client:
                 See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+            details (`bool`):
+                Return the generation details
 
         Returns:
             Iterator[StreamResponse]: stream of generated tokens
@@ -201,7 +207,7 @@ class Client:
         # Validate parameters
         parameters = Parameters(
             best_of=None,
-            details=True,
+            details=details,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
             repetition_penalty=repetition_penalty,
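The same flag is added to `Client.generate_stream`, where `best_of` stays pinned to `None`. When requested, the details only arrive with the final `StreamResponse` message. A short sketch against the same placeholder endpoint:

    from text_generation import Client

    client = Client("http://127.0.0.1:8080")

    text = ""
    for stream_response in client.generate_stream("Why is the sky blue?", max_new_tokens=20):
        if not stream_response.token.special:
            text += stream_response.token.text
        # details arrive with the last message only, and only when requested
        if stream_response.details is not None:
            print(stream_response.details.finish_reason)
    print(text)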
@@ -311,6 +317,7 @@ class AsyncClient:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
+        details: bool = True,
     ) -> Response:
         """
         Given a prompt, generate the following text asynchronously
@@ -347,6 +354,8 @@ class AsyncClient:
                 See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+            details (`bool`):
+                Return the generation details
 
         Returns:
             Response: generated response
@@ -354,7 +363,7 @@ class AsyncClient:
         # Validate parameters
         parameters = Parameters(
             best_of=best_of,
-            details=True,
+            details=details,
            do_sample=do_sample,
             max_new_tokens=max_new_tokens,
             repetition_penalty=repetition_penalty,
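`AsyncClient.generate` gets the identical treatment; a sketch of the async call with details disabled (the endpoint URL is again a placeholder):

    import asyncio

    from text_generation import AsyncClient

    async def main():
        client = AsyncClient("http://127.0.0.1:8080")
        response = await client.generate("Why is the sky blue?", max_new_tokens=20, details=False)
        # With details disabled, only the generated text is of interest
        print(response.generated_text, response.details)

    asyncio.run(main())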
@@ -395,6 +404,7 @@ class AsyncClient:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
+        details: bool = True,
     ) -> AsyncIterator[StreamResponse]:
         """
         Given a prompt, generate the following stream of tokens asynchronously
@@ -429,6 +439,8 @@ class AsyncClient:
                 See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+            details (`bool`):
+                Return the generation details
 
         Returns:
             AsyncIterator[StreamResponse]: stream of generated tokens
@@ -436,7 +448,7 @@ class AsyncClient:
         # Validate parameters
         parameters = Parameters(
             best_of=None,
-            details=True,
+            details=details,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
             repetition_penalty=repetition_penalty,
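And the streaming variant on `AsyncClient`, consumed with `async for`; a minimal sketch under the same assumptions:

    import asyncio

    from text_generation import AsyncClient

    async def main():
        client = AsyncClient("http://127.0.0.1:8080")
        async for stream_response in client.generate_stream("Why is the sky blue?", max_new_tokens=20):
            print(stream_response.token.text, end="", flush=True)
        print()

    asyncio.run(main())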
@@ -200,7 +200,7 @@ class Response(BaseModel):
     # Generated text
     generated_text: str
     # Generation details
-    details: Details
+    details: Optional[Details]
 
 
 # `generate_stream` details
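Because `Response.details` is now `Optional[Details]`, callers that previously read it unconditionally should guard against `None`. An illustrative pattern (the helper function is made up for this example):

    from typing import Optional

    from text_generation.types import Details, Response

    def finish_reason_or_unknown(response: Response) -> str:
        # details is None when the request was sent with details=False
        details: Optional[Details] = response.details
        return details.finish_reason.value if details is not None else "unknown"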
@@ -367,9 +367,6 @@ class VectorizedNextTokenChooser:
         return values
 
     def __call__(self, input_ids, scores):
-        # Only process the last token
-        scores=scores[: -1, :]
-
         if self.repetition_penalty_t is not None:
             score = torch.gather(scores, 1, input_ids)
             # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
@@ -420,8 +417,6 @@ class VectorizedNextTokenChooser:
             indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
             scores = scores.masked_fill(indices_to_remove, self.filter_value)
 
-        # Compute logprobs
-        logprobs = torch.log_softmax(scores, dim=-1)
 
         if self.num_do_sample:
             probs = torch.nn.functional.softmax(scores, -1)
@@ -431,6 +426,9 @@ class VectorizedNextTokenChooser:
         else:
             next_token_ids = torch.argmax(scores, dim=-1)
 
+        # Compute logprobs
+        logprobs = torch.log_softmax(scores, dim=-1).gather(1, next_token_ids.unsqueeze(1)).squeeze(1)
+
         return next_token_ids, logprobs
 
     @classmethod
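The last three hunks are the fixes: the `scores[: -1, :]` slice (which did not do what the `# Only process the last token` comment intended) is removed, and the log-probability computation is moved after the token choice so that only each chosen token's logprob is gathered instead of keeping full-vocabulary logprobs. A standalone sketch of that gather pattern on dummy tensors (shapes and values are illustrative only):

    import torch

    batch_size, vocab_size = 4, 32000
    scores = torch.randn(batch_size, vocab_size)    # logits for the last position of each sequence
    next_token_ids = torch.argmax(scores, dim=-1)   # greedy choice, shape (batch_size,)

    # log-softmax over the vocabulary, then keep only each row's chosen-token logprob
    logprobs = (
        torch.log_softmax(scores, dim=-1)
        .gather(1, next_token_ids.unsqueeze(1))
        .squeeze(1)
    )                                               # shape (batch_size,)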