Mirror of https://github.com/huggingface/text-generation-inference.git
Fixes and generation details arg
Parent: b3b1b81982
Commit: a7c10f710f
@@ -74,6 +74,7 @@ class Client:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
+        details: bool = True,
     ) -> Response:
         """
         Given a prompt, generate the following text
@@ -110,6 +111,8 @@ class Client:
                 See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+            details (`bool`):
+                Return the generation details
 
         Returns:
             Response: generated response
@@ -117,7 +120,7 @@ class Client:
         # Validate parameters
         parameters = Parameters(
             best_of=best_of,
-            details=True,
+            details=details,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
             repetition_penalty=repetition_penalty,
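For context, a minimal usage sketch of the new `details` argument on the synchronous client. The endpoint, prompt, and token budget below are placeholders for illustration, not part of this commit:

from text_generation import Client

client = Client("http://127.0.0.1:8080")  # placeholder endpoint

# details=True (the default) asks the server for token-level generation details.
response = client.generate("What is Deep Learning?", max_new_tokens=17, details=True)
print(response.generated_text)
if response.details is not None:
    print(response.details.generated_tokens, response.details.finish_reason)

# details=False skips them; `response.details` is then expected to be None,
# which is why the Response model is relaxed to Optional[Details] further down.
response = client.generate("What is Deep Learning?", max_new_tokens=17, details=False)
print(response.details)  # expected: None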
@@ -160,6 +163,7 @@ class Client:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
+        details: bool = True,
     ) -> Iterator[StreamResponse]:
         """
         Given a prompt, generate the following stream of tokens
@@ -194,6 +198,8 @@ class Client:
                 See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+            details (`bool`):
+                Return the generation details
 
         Returns:
             Iterator[StreamResponse]: stream of generated tokens
@@ -201,7 +207,7 @@ class Client:
         # Validate parameters
         parameters = Parameters(
             best_of=None,
-            details=True,
+            details=details,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
             repetition_penalty=repetition_penalty,
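A matching sketch for the streaming path, under the same assumptions (local server, placeholder prompt); with details=True, the final StreamResponse is the one expected to carry a StreamDetails payload:

from text_generation import Client

client = Client("http://127.0.0.1:8080")  # placeholder endpoint
text = ""
for stream_response in client.generate_stream(
    "What is Deep Learning?", max_new_tokens=17, details=True
):
    if not stream_response.token.special:
        text += stream_response.token.text
    if stream_response.details is not None:
        # Only the last message should carry details (finish reason, token count, seed).
        print(stream_response.details.finish_reason)
print(text)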
@@ -311,6 +317,7 @@ class AsyncClient:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
+        details: bool = True,
     ) -> Response:
         """
         Given a prompt, generate the following text asynchronously
@@ -347,6 +354,8 @@ class AsyncClient:
                 See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+            details (`bool`):
+                Return the generation details
 
         Returns:
             Response: generated response
@@ -354,7 +363,7 @@ class AsyncClient:
         # Validate parameters
         parameters = Parameters(
             best_of=best_of,
-            details=True,
+            details=details,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
             repetition_penalty=repetition_penalty,
@@ -395,6 +404,7 @@ class AsyncClient:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
+        details: bool = True,
     ) -> AsyncIterator[StreamResponse]:
         """
         Given a prompt, generate the following stream of tokens asynchronously
@@ -429,6 +439,8 @@ class AsyncClient:
                 See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+            details (`bool`):
+                Return the generation details
 
         Returns:
             AsyncIterator[StreamResponse]: stream of generated tokens
@@ -436,7 +448,7 @@ class AsyncClient:
         # Validate parameters
         parameters = Parameters(
             best_of=None,
-            details=True,
+            details=details,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
             repetition_penalty=repetition_penalty,
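The async client takes the same argument; a brief sketch, with the endpoint and prompt again being placeholders:

import asyncio
from text_generation import AsyncClient

async def main():
    client = AsyncClient("http://127.0.0.1:8080")  # placeholder endpoint
    response = await client.generate("What is Deep Learning?", max_new_tokens=17, details=False)
    # With details disabled the server may omit the details payload entirely.
    print(response.generated_text, response.details)

asyncio.run(main())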
@@ -200,7 +200,7 @@ class Response(BaseModel):
     # Generated text
     generated_text: str
     # Generation details
-    details: Details
+    details: Optional[Details]
 
 
 # `generate_stream` details
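Because `details` can now be absent from the parsed response, callers reading it should guard for None. An illustrative helper (the field names come from the existing Details model; the helper itself is not part of the commit):

def finish_reason_or_none(response):
    # `details` is None when the request was made with details=False
    return response.details.finish_reason if response.details is not None else None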
@@ -367,9 +367,6 @@ class VectorizedNextTokenChooser:
         return values
 
     def __call__(self, input_ids, scores):
-        # Only process the last token
-        scores=scores[: -1, :]
-
         if self.repetition_penalty_t is not None:
             score = torch.gather(scores, 1, input_ids)
             # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
@@ -420,8 +417,6 @@ class VectorizedNextTokenChooser:
             indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
             scores = scores.masked_fill(indices_to_remove, self.filter_value)
 
-        # Compute logprobs
-        logprobs = torch.log_softmax(scores, dim=-1)
 
         if self.num_do_sample:
             probs = torch.nn.functional.softmax(scores, -1)
@@ -431,6 +426,9 @@ class VectorizedNextTokenChooser:
         else:
             next_token_ids = torch.argmax(scores, dim=-1)
 
+        # Compute logprobs
+        logprobs = torch.log_softmax(scores, dim=-1).gather(1, next_token_ids.unsqueeze(1)).squeeze(1)
+
         return next_token_ids, logprobs
 
     @classmethod
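For reference, a standalone sketch of the gather-based log-probability computation introduced above; the batch size, vocabulary size, and greedy selection are assumptions for illustration only:

import torch

scores = torch.randn(4, 32000)             # [batch, vocab] logits after warping
next_token_ids = torch.argmax(scores, -1)  # chosen token per sequence, shape [batch]

# Log-softmax over the vocabulary, then keep only the logprob of each chosen token.
# gather expects an index tensor of shape [batch, 1], hence the unsqueeze/squeeze.
logprobs = (
    torch.log_softmax(scores, dim=-1)
    .gather(1, next_token_ids.unsqueeze(1))
    .squeeze(1)
)                                          # shape [batch]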