Return more top-n-tokens when probabilities are equal

This commit is contained in:
Vincent Brouwers 2023-07-28 14:21:11 +00:00
parent 50d05fa20d
commit 95d0fba7de
2 changed files with 67 additions and 20 deletions

View File

@ -1,7 +1,9 @@
import torch
from text_generation_server.utils.tokens import ( from text_generation_server.utils.tokens import (
StopSequenceCriteria, StopSequenceCriteria,
StoppingCriteria, StoppingCriteria,
FinishReason, FinishReason,
batch_top_tokens,
) )
@ -42,3 +44,21 @@ def test_stopping_criteria_max():
assert criteria(1, "") == (False, None) assert criteria(1, "") == (False, None)
assert criteria(1, "") == (False, None) assert criteria(1, "") == (False, None)
assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH) assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)
def test_batch_top_tokens():
top_n_tokens = [0, 2, 3, 4, 5]
inp_logprobs = torch.tensor([[-1., -3., -4., -2., -3.]] * 5)
topn_tok_ids, topn_tok_logprobs = batch_top_tokens(top_n_tokens, inp_logprobs)
assert topn_tok_ids[0] == []
assert topn_tok_ids[1] == [0, 3]
assert topn_tok_ids[2] == [0, 3, 1, 4]
assert topn_tok_ids[3] == [0, 3, 1, 4]
assert topn_tok_ids[4] == [0, 3, 1, 4, 2]
assert topn_tok_logprobs[0] == []
assert topn_tok_logprobs[1] == [-1, -2]
assert topn_tok_logprobs[2] == [-1, -2, -3, -3]
assert topn_tok_logprobs[3] == [-1, -2, -3, -3]
assert topn_tok_logprobs[4] == [-1, -2, -3, -3, -4]

View File

@ -1,24 +1,20 @@
import re import re
from typing import Callable, List, Tuple, Optional from typing import Callable, List, Optional, Tuple
import torch import torch
from transformers import (
RepetitionPenaltyLogitsProcessor,
PreTrainedTokenizerBase,
)
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation_server.pb.generate_pb2 import FinishReason from text_generation_server.pb.generate_pb2 import FinishReason
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
from text_generation_server.utils.logits_process import ( from text_generation_server.utils.logits_process import (
static_warper, HeterogeneousProcessorWrapper,
HeterogeneousRepetitionPenaltyLogitsProcessor, HeterogeneousRepetitionPenaltyLogitsProcessor,
HeterogeneousTemperatureLogitsWarper, HeterogeneousTemperatureLogitsWarper,
HeterogeneousTopKLogitsWarper, HeterogeneousTopKLogitsWarper,
HeterogeneousTopPLogitsWarper, HeterogeneousTopPLogitsWarper,
HeterogeneousTypicalLogitsWarper, HeterogeneousTypicalLogitsWarper,
HeterogeneousProcessorWrapper, static_warper,
) )
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
class NextTokenChooser: class NextTokenChooser:
@ -340,23 +336,54 @@ class HeterogeneousSampling:
return self return self
def batch_top_tokens(top_n_tokens: torch.Tensor, logprobs: torch.Tensor): def batch_top_tokens(
"""Find the top n most likely tokens for a batch of generations.""" top_n_tokens: list[int], logprobs: torch.Tensor
top_n_tokens = torch.tensor(top_n_tokens) ) -> Tuple[List[List[int]], List[List[float]]]:
if top_n_tokens.min() == 0: """Find the top n most likely tokens for a batch of generations.
When multiple tokens have equal probabilities and they don't all fit, the
remaining tokens are also returned.
"""
# Do this as early as possible to mitigate copy latency
top_n_tensor = torch.tensor(top_n_tokens).to(
device=logprobs.device, non_blocking=True
)
# Early exit when top_n_tokens is not used
if max(top_n_tokens) == 0:
return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens) return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens)
# Ensure top_n doesn't exceed vocab size # Ensure top_n doesn't exceed vocab size
top_n_tokens = torch.clip(top_n_tokens, max=logprobs.size(-1)) top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens]
# Take the topk using the highest requested top_n_tokens. # Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
top_k = torch.topk(logprobs, k=max(top_n_tokens), dim=1, sorted=True) # Sorted topk is faster than torch.sort() since we only need a small subset
sorted_top_k = torch.topk(
logprobs, k=max(top_n_tokens), dim=1, sorted=True
).values # .cpu()
nth_highest = torch.gather(
sorted_top_k, 1, (top_n_tensor - 1).clip(min=0).unsqueeze(1)
)
nth_highest[nth_highest == -float("inf")] = torch.finfo(logprobs.dtype).min
# Move all digits into a list at once to prevent multiple GPU syncs # Find the new "fuzzy" top n values
top_n_indices = (logprobs >= nth_highest).nonzero()
_, top_n_ishes = torch.unique_consecutive(top_n_indices[:, 0], return_counts=True)
# Take a new topk for these new max n values
top_k = torch.topk(logprobs, k=top_n_ishes.max(), dim=1, sorted=True)
top_n_ishes = top_n_ishes.tolist()
top_indices = top_k.indices.tolist() top_indices = top_k.indices.tolist()
top_values = top_k.values.tolist() top_values = top_k.values.tolist()
return ( return (
[idxs[:n] for idxs, n in zip(top_indices, top_n_tokens)], [
[vals[:n] for vals, n in zip(top_values, top_n_tokens)], idxs[:n] if req_n > 0 else []
for idxs, n, req_n in zip(top_indices, top_n_ishes, top_n_tokens)
],
[
vals[:n] if req_n > 0 else []
for vals, n, req_n in zip(top_values, top_n_ishes, top_n_tokens)
],
) )