Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-10 20:04:52 +00:00
Return more top-n-tokens when probabilities are equal
This commit is contained in:
parent 8515999b1d
commit 65c0d9c19d
@@ -1,7 +1,9 @@
+import torch
 from text_generation_server.utils.tokens import (
     StopSequenceCriteria,
     StoppingCriteria,
     FinishReason,
+    batch_top_tokens,
 )
 
 
||||||
@@ -42,3 +44,21 @@ def test_stopping_criteria_max():
     assert criteria(1, "") == (False, None)
     assert criteria(1, "") == (False, None)
     assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)
+
+def test_batch_top_tokens():
+    top_n_tokens = [0, 2, 3, 4, 5]
+    inp_logprobs = torch.tensor([[-1., -3., -4., -2., -3.]] * 5)
+
+    topn_tok_ids, topn_tok_logprobs = batch_top_tokens(top_n_tokens, inp_logprobs)
+
+    assert topn_tok_ids[0] == []
+    assert topn_tok_ids[1] == [0, 3]
+    assert topn_tok_ids[2] == [0, 3, 1, 4]
+    assert topn_tok_ids[3] == [0, 3, 1, 4]
+    assert topn_tok_ids[4] == [0, 3, 1, 4, 2]
+
+    assert topn_tok_logprobs[0] == []
+    assert topn_tok_logprobs[1] == [-1, -2]
+    assert topn_tok_logprobs[2] == [-1, -2, -3, -3]
+    assert topn_tok_logprobs[3] == [-1, -2, -3, -3]
+    assert topn_tok_logprobs[4] == [-1, -2, -3, -3, -4]
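To see where the extra entries come from, take any row of the test input: the logprobs are [-1., -3., -4., -2., -3.], so for top_n = 3 the third-highest value (-3) is shared by token ids 1 and 4 and both are kept, which is why topn_tok_ids[2] has four entries. A rough per-row sketch of that tie-inclusive rule (illustrative only, not the batched server implementation; the helper name is made up):

import torch


def naive_top_n_with_ties(top_n: int, row_logprobs: torch.Tensor):
    # Keep every token whose logprob ties with the top_n-th highest value,
    # so the result may be longer than top_n when there are ties.
    if top_n == 0:
        return [], []
    vals, ids = row_logprobs.sort(descending=True, stable=True)
    threshold = vals[min(top_n, row_logprobs.numel()) - 1]
    keep = vals >= threshold
    return ids[keep].tolist(), vals[keep].tolist()


ids, vals = naive_top_n_with_ties(3, torch.tensor([-1., -3., -4., -2., -3.]))
# ids == [0, 3, 1, 4], vals == [-1.0, -2.0, -3.0, -3.0]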
@@ -1,24 +1,20 @@
 import re
-from typing import Callable, List, Tuple, Optional
+from typing import Callable, List, Optional, Tuple
 
 import torch
 
-from transformers import (
-    RepetitionPenaltyLogitsProcessor,
-    PreTrainedTokenizerBase,
-)
-
 from text_generation_server.pb import generate_pb2
 from text_generation_server.pb.generate_pb2 import FinishReason
-from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 from text_generation_server.utils.logits_process import (
-    static_warper,
+    HeterogeneousProcessorWrapper,
     HeterogeneousRepetitionPenaltyLogitsProcessor,
     HeterogeneousTemperatureLogitsWarper,
     HeterogeneousTopKLogitsWarper,
     HeterogeneousTopPLogitsWarper,
     HeterogeneousTypicalLogitsWarper,
-    HeterogeneousProcessorWrapper,
+    static_warper,
 )
+from text_generation_server.utils.watermark import WatermarkLogitsProcessor
+from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
 
 
 class NextTokenChooser:
@@ -340,23 +336,54 @@ class HeterogeneousSampling:
         return self
 
 
-def batch_top_tokens(top_n_tokens: torch.Tensor, logprobs: torch.Tensor):
-    """Find the top n most likely tokens for a batch of generations."""
-    top_n_tokens = torch.tensor(top_n_tokens)
-    if top_n_tokens.min() == 0:
+def batch_top_tokens(
+    top_n_tokens: list[int], logprobs: torch.Tensor
+) -> Tuple[List[List[int]], List[List[float]]]:
+    """Find the top n most likely tokens for a batch of generations.
+
+    When multiple tokens have equal probabilities and they don't all fit, the
+    remaining tokens are also returned.
+    """
+    # Do this as early as possible to mitigate copy latency
+    top_n_tensor = torch.tensor(top_n_tokens).to(
+        device=logprobs.device, non_blocking=True
+    )
+
+    # Early exit when top_n_tokens is not used
+    if max(top_n_tokens) == 0:
         return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens)
 
     # Ensure top_n doesn't exceed vocab size
-    top_n_tokens = torch.clip(top_n_tokens, max=logprobs.size(-1))
+    top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens]
 
-    # Take the topk using the highest requested top_n_tokens.
-    top_k = torch.topk(logprobs, k=max(top_n_tokens), dim=1, sorted=True)
+    # Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
+    # Sorted topk is faster than torch.sort() since we only need a small subset
+    sorted_top_k = torch.topk(
+        logprobs, k=max(top_n_tokens), dim=1, sorted=True
+    ).values  # .cpu()
+    nth_highest = torch.gather(
+        sorted_top_k, 1, (top_n_tensor - 1).clip(min=0).unsqueeze(1)
+    )
+    nth_highest[nth_highest == -float("inf")] = torch.finfo(logprobs.dtype).min
 
-    # Move all digits into a list at once to prevent multiple GPU syncs
+    # Find the new "fuzzy" top n values
+    top_n_indices = (logprobs >= nth_highest).nonzero()
+    _, top_n_ishes = torch.unique_consecutive(top_n_indices[:, 0], return_counts=True)
+
+    # Take a new topk for these new max n values
+    top_k = torch.topk(logprobs, k=top_n_ishes.max(), dim=1, sorted=True)
+
+    top_n_ishes = top_n_ishes.tolist()
     top_indices = top_k.indices.tolist()
     top_values = top_k.values.tolist()
 
     return (
-        [idxs[:n] for idxs, n in zip(top_indices, top_n_tokens)],
-        [vals[:n] for vals, n in zip(top_values, top_n_tokens)],
+        [
+            idxs[:n] if req_n > 0 else []
+            for idxs, n, req_n in zip(top_indices, top_n_ishes, top_n_tokens)
+        ],
+        [
+            vals[:n] if req_n > 0 else []
+            for vals, n, req_n in zip(top_values, top_n_ishes, top_n_tokens)
+        ],
     )
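The heavy lifting in the new function is finding each row's n-th highest logprob without a Python loop, the "parallel kthvalue" trick referenced in the comment: one sorted topk sized by the largest request, then a gather indexed by each row's own n. A standalone sketch of just that step with made-up sample values (plain PyTorch, not code from this repository):

import torch

# Made-up example: three rows of logprobs, each requesting a different top_n.
logprobs = torch.tensor([[-1., -3., -4., -2., -3.]] * 3)
top_n = torch.tensor([1, 3, 5])

# Single sorted topk sized by the largest request, then gather picks, per row,
# the value at that row's own (top_n - 1) position: its n-th highest logprob.
sorted_vals = torch.topk(logprobs, k=int(top_n.max()), dim=1, sorted=True).values
nth_highest = torch.gather(sorted_vals, 1, (top_n - 1).clip(min=0).unsqueeze(1))

# Everything at or above its row's threshold is kept, so values tied with the
# n-th highest enlarge the per-row count instead of being cut off arbitrarily.
counts = (logprobs >= nth_highest).sum(dim=1)
print(nth_highest.squeeze(1).tolist())  # [-1.0, -3.0, -4.0]
print(counts.tolist())                  # [1, 4, 5]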