mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
Return more top-n-tokens when probabilities are equal
parent 8515999b1d
commit 65c0d9c19d
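As a rough illustration of the behavior this change targets (a sketch using the same illustrative logprobs as the new test below, not code taken from the commit): a plain top-k must drop one of two equally likely tokens, while the new rule keeps every token whose log-probability is at least the n-th highest.

import torch

# Illustrative values matching the test below: token 0 scores -1.0,
# token 3 scores -2.0, and tokens 1 and 4 tie at -3.0.
logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]])
n = 3

plain = torch.topk(logprobs, k=n, dim=1, sorted=True)
print(plain.indices.tolist())  # one of the tied tokens (1 or 4) is dropped

# "Fuzzy" top-n: keep every token whose logprob is >= the n-th highest value,
# so both tied tokens are returned.
nth = plain.values[:, -1:]
print((logprobs >= nth).nonzero()[:, 1].tolist())  # [0, 1, 3, 4]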
@@ -1,7 +1,9 @@
+import torch
 from text_generation_server.utils.tokens import (
     StopSequenceCriteria,
     StoppingCriteria,
     FinishReason,
+    batch_top_tokens,
 )
@@ -42,3 +44,21 @@ def test_stopping_criteria_max():
     assert criteria(1, "") == (False, None)
     assert criteria(1, "") == (False, None)
     assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)
+
+def test_batch_top_tokens():
+    top_n_tokens = [0, 2, 3, 4, 5]
+    inp_logprobs = torch.tensor([[-1., -3., -4., -2., -3.]] * 5)
+
+    topn_tok_ids, topn_tok_logprobs = batch_top_tokens(top_n_tokens, inp_logprobs)
+
+    assert topn_tok_ids[0] == []
+    assert topn_tok_ids[1] == [0, 3]
+    assert topn_tok_ids[2] == [0, 3, 1, 4]
+    assert topn_tok_ids[3] == [0, 3, 1, 4]
+    assert topn_tok_ids[4] == [0, 3, 1, 4, 2]
+
+    assert topn_tok_logprobs[0] == []
+    assert topn_tok_logprobs[1] == [-1, -2]
+    assert topn_tok_logprobs[2] == [-1, -2, -3, -3]
+    assert topn_tok_logprobs[3] == [-1, -2, -3, -3]
+    assert topn_tok_logprobs[4] == [-1, -2, -3, -3, -4]
@@ -1,24 +1,20 @@
 import re
-from typing import Callable, List, Tuple, Optional
+from typing import Callable, List, Optional, Tuple

 import torch

-from transformers import (
-    RepetitionPenaltyLogitsProcessor,
-    PreTrainedTokenizerBase,
-)
-
 from text_generation_server.pb import generate_pb2
 from text_generation_server.pb.generate_pb2 import FinishReason
-from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 from text_generation_server.utils.logits_process import (
-    static_warper,
+    HeterogeneousProcessorWrapper,
     HeterogeneousRepetitionPenaltyLogitsProcessor,
     HeterogeneousTemperatureLogitsWarper,
     HeterogeneousTopKLogitsWarper,
     HeterogeneousTopPLogitsWarper,
     HeterogeneousTypicalLogitsWarper,
-    HeterogeneousProcessorWrapper,
+    static_warper,
 )
+from text_generation_server.utils.watermark import WatermarkLogitsProcessor
+from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
@@ -340,23 +336,54 @@ class HeterogeneousSampling:
         return self


-def batch_top_tokens(top_n_tokens: torch.Tensor, logprobs: torch.Tensor):
-    """Find the top n most likely tokens for a batch of generations."""
-    top_n_tokens = torch.tensor(top_n_tokens)
-    if top_n_tokens.min() == 0:
+def batch_top_tokens(
+    top_n_tokens: list[int], logprobs: torch.Tensor
+) -> Tuple[List[List[int]], List[List[float]]]:
+    """Find the top n most likely tokens for a batch of generations.
+
+    When multiple tokens have equal probabilities and they don't all fit, the
+    remaining tokens are also returned.
+    """
+    # Do this as early as possible to mitigate copy latency
+    top_n_tensor = torch.tensor(top_n_tokens).to(
+        device=logprobs.device, non_blocking=True
+    )
+
+    # Early exit when top_n_tokens is not used
+    if max(top_n_tokens) == 0:
         return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens)

     # Ensure top_n doesn't exceed vocab size
-    top_n_tokens = torch.clip(top_n_tokens, max=logprobs.size(-1))
+    top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens]

-    # Take the topk using the highest requested top_n_tokens.
-    top_k = torch.topk(logprobs, k=max(top_n_tokens), dim=1, sorted=True)
+    # Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
+    # Sorted topk is faster than torch.sort() since we only need a small subset
+    sorted_top_k = torch.topk(
+        logprobs, k=max(top_n_tokens), dim=1, sorted=True
+    ).values  # .cpu()
+    nth_highest = torch.gather(
+        sorted_top_k, 1, (top_n_tensor - 1).clip(min=0).unsqueeze(1)
+    )
+    nth_highest[nth_highest == -float("inf")] = torch.finfo(logprobs.dtype).min

-    # Move all digits into a list at once to prevent multiple GPU syncs
+    # Find the new "fuzzy" top n values
+    top_n_indices = (logprobs >= nth_highest).nonzero()
+    _, top_n_ishes = torch.unique_consecutive(top_n_indices[:, 0], return_counts=True)
+
+    # Take a new topk for these new max n values
+    top_k = torch.topk(logprobs, k=top_n_ishes.max(), dim=1, sorted=True)
+
+    top_n_ishes = top_n_ishes.tolist()
     top_indices = top_k.indices.tolist()
     top_values = top_k.values.tolist()

     return (
-        [idxs[:n] for idxs, n in zip(top_indices, top_n_tokens)],
-        [vals[:n] for vals, n in zip(top_values, top_n_tokens)],
+        [
+            idxs[:n] if req_n > 0 else []
+            for idxs, n, req_n in zip(top_indices, top_n_ishes, top_n_tokens)
+        ],
+        [
+            vals[:n] if req_n > 0 else []
+            for vals, n, req_n in zip(top_values, top_n_ishes, top_n_tokens)
+        ],
     )
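For reference, a minimal standalone sketch (not part of the diff; values are illustrative) of the per-row k-th value trick the new code builds on: one sorted topk sized by the largest requested n, then a gather of each row's own n-th value so every row gets its threshold in parallel.

import torch

logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 3)
top_n = torch.tensor([1, 2, 3])  # per-row n (illustrative)

# One topk sized by the largest request, then gather the n-th column per row.
sorted_top_k = torch.topk(logprobs, k=int(top_n.max()), dim=1, sorted=True).values
nth_highest = torch.gather(sorted_top_k, 1, (top_n - 1).clip(min=0).unsqueeze(1))
print(nth_highest.squeeze(1).tolist())  # [-1.0, -2.0, -3.0]

# Every token at or above its row's threshold is kept, so ties are not cut off.
keep_counts = (logprobs >= nth_highest).sum(dim=1)
print(keep_counts.tolist())  # [1, 2, 4] -- row 2 keeps 4 tokens because of the tie at -3.0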