diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index 69177d56..7b003f1d 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -363,9 +363,10 @@ def batch_top_tokens( # Find the new "fuzzy" top n values top_n_indices = (logprobs >= nth_highest).nonzero() _, top_n_ishes = torch.unique_consecutive(top_n_indices[:, 0], return_counts=True) - + + k = 1 if top_n_ishes.numel() == 0 else top_n_ishes.max() # Take a new topk for these new max n values - top_k = torch.topk(logprobs, k=top_n_ishes.max(), dim=1, sorted=True) + top_k = torch.topk(logprobs, k=k, dim=1, sorted=True) top_n_ishes = top_n_ishes.tolist() top_indices = top_k.indices.tolist()