From 65c0d9c19d7683fd0b7945c6548371aa071591a5 Mon Sep 17 00:00:00 2001
From: Vincent Brouwers
Date: Fri, 28 Jul 2023 14:21:11 +0000
Subject: [PATCH] Return more top-n-tokens when probabilities are equal

---
 server/tests/utils/test_tokens.py             | 20 ++++
 server/text_generation_server/utils/tokens.py | 67 +++++++++++++------
 2 files changed, 67 insertions(+), 20 deletions(-)

diff --git a/server/tests/utils/test_tokens.py b/server/tests/utils/test_tokens.py
index da0006e4..46b1220f 100644
--- a/server/tests/utils/test_tokens.py
+++ b/server/tests/utils/test_tokens.py
@@ -1,7 +1,9 @@
+import torch
 from text_generation_server.utils.tokens import (
     StopSequenceCriteria,
     StoppingCriteria,
     FinishReason,
+    batch_top_tokens,
 )
 
 
@@ -42,3 +44,21 @@ def test_stopping_criteria_max():
     assert criteria(1, "") == (False, None)
     assert criteria(1, "") == (False, None)
     assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)
+
+def test_batch_top_tokens():
+    top_n_tokens = [0, 2, 3, 4, 5]
+    inp_logprobs = torch.tensor([[-1., -3., -4., -2., -3.]] * 5)
+
+    topn_tok_ids, topn_tok_logprobs = batch_top_tokens(top_n_tokens, inp_logprobs)
+
+    assert topn_tok_ids[0] == []
+    assert topn_tok_ids[1] == [0, 3]
+    assert topn_tok_ids[2] == [0, 3, 1, 4]
+    assert topn_tok_ids[3] == [0, 3, 1, 4]
+    assert topn_tok_ids[4] == [0, 3, 1, 4, 2]
+
+    assert topn_tok_logprobs[0] == []
+    assert topn_tok_logprobs[1] == [-1, -2]
+    assert topn_tok_logprobs[2] == [-1, -2, -3, -3]
+    assert topn_tok_logprobs[3] == [-1, -2, -3, -3]
+    assert topn_tok_logprobs[4] == [-1, -2, -3, -3, -4]
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index db7f9510..fe40b338 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -1,24 +1,20 @@
 import re
-from typing import Callable, List, Tuple, Optional
+from typing import Callable, List, Optional, Tuple
+
 import torch
-
-from transformers import (
-    RepetitionPenaltyLogitsProcessor,
-    PreTrainedTokenizerBase,
-)
-
 from text_generation_server.pb import generate_pb2
 from text_generation_server.pb.generate_pb2 import FinishReason
-from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 from text_generation_server.utils.logits_process import (
-    static_warper,
+    HeterogeneousProcessorWrapper,
     HeterogeneousRepetitionPenaltyLogitsProcessor,
     HeterogeneousTemperatureLogitsWarper,
     HeterogeneousTopKLogitsWarper,
     HeterogeneousTopPLogitsWarper,
     HeterogeneousTypicalLogitsWarper,
-    HeterogeneousProcessorWrapper,
+    static_warper,
 )
+from text_generation_server.utils.watermark import WatermarkLogitsProcessor
+from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
 
 
 class NextTokenChooser:
@@ -340,23 +336,54 @@ class HeterogeneousSampling:
         return self
 
 
-def batch_top_tokens(top_n_tokens: torch.Tensor, logprobs: torch.Tensor):
-    """Find the top n most likely tokens for a batch of generations."""
-    top_n_tokens = torch.tensor(top_n_tokens)
-    if top_n_tokens.min() == 0:
+def batch_top_tokens(
+    top_n_tokens: list[int], logprobs: torch.Tensor
+) -> Tuple[List[List[int]], List[List[float]]]:
+    """Find the top n most likely tokens for a batch of generations.
+
+    When multiple tokens have equal probabilities and they don't all fit, the
+    remaining tokens are also returned.
+    """
+    # Do this as early as possible to mitigate copy latency
+    top_n_tensor = torch.tensor(top_n_tokens).to(
+        device=logprobs.device, non_blocking=True
+    )
+
+    # Early exit when top_n_tokens is not used
+    if max(top_n_tokens) == 0:
         return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens)
 
     # Ensure top_n doesn't exceed vocab size
-    top_n_tokens = torch.clip(top_n_tokens, max=logprobs.size(-1))
+    top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens]
 
-    # Take the topk using the highest requested top_n_tokens.
-    top_k = torch.topk(logprobs, k=max(top_n_tokens), dim=1, sorted=True)
+    # Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
+    # Sorted topk is faster than torch.sort() since we only need a small subset
+    sorted_top_k = torch.topk(
+        logprobs, k=max(top_n_tokens), dim=1, sorted=True
+    ).values  # .cpu()
+    nth_highest = torch.gather(
+        sorted_top_k, 1, (top_n_tensor - 1).clip(min=0).unsqueeze(1)
+    )
+    nth_highest[nth_highest == -float("inf")] = torch.finfo(logprobs.dtype).min
 
-    # Move all digits into a list at once to prevent multiple GPU syncs
+    # Find the new "fuzzy" top n values
+    top_n_indices = (logprobs >= nth_highest).nonzero()
+    _, top_n_ishes = torch.unique_consecutive(top_n_indices[:, 0], return_counts=True)
+
+    # Take a new topk for these new max n values
+    top_k = torch.topk(logprobs, k=top_n_ishes.max(), dim=1, sorted=True)
+
+    top_n_ishes = top_n_ishes.tolist()
     top_indices = top_k.indices.tolist()
     top_values = top_k.values.tolist()
 
     return (
-        [idxs[:n] for idxs, n in zip(top_indices, top_n_tokens)],
-        [vals[:n] for vals, n in zip(top_values, top_n_tokens)],
+        [
+            idxs[:n] if req_n > 0 else []
+            for idxs, n, req_n in zip(top_indices, top_n_ishes, top_n_tokens)
+        ],
+        [
+            vals[:n] if req_n > 0 else []
+            for vals, n, req_n in zip(top_values, top_n_ishes, top_n_tokens)
+        ],
     )