From e74a68ee70487d2f3b39c1fc0d60441b5ccc0c65 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Thu, 31 Aug 2023 20:51:17 +0200
Subject: [PATCH] Fix top_k when k < 0

---
 server/text_generation_server/utils/tokens.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 69177d56..7b003f1d 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -363,9 +363,10 @@ def batch_top_tokens(
     # Find the new "fuzzy" top n values
     top_n_indices = (logprobs >= nth_highest).nonzero()
     _, top_n_ishes = torch.unique_consecutive(top_n_indices[:, 0], return_counts=True)
-
+
+    k = 1 if top_n_ishes.numel() == 0 else top_n_ishes.max()
     # Take a new topk for these new max n values
-    top_k = torch.topk(logprobs, k=top_n_ishes.max(), dim=1, sorted=True)
+    top_k = torch.topk(logprobs, k=k, dim=1, sorted=True)
     top_n_ishes = top_n_ishes.tolist()
     top_indices = top_k.indices.tolist()
     top_values = top_k.values.tolist()
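
Why the `k = 1` fallback is needed, as a minimal standalone sketch: when no
logprob clears the `nth_highest` threshold, `top_n_indices` is empty, so the
per-row counts returned by `torch.unique_consecutive` are empty too, and
calling `.max()` on an empty tensor raises a RuntimeError before `torch.topk`
is ever reached. The tensor values below are invented for illustration; only
the final two statements mirror the patched lines.

    import torch

    # Hypothetical inputs: one row whose logprobs all fall below the threshold.
    logprobs = torch.tensor([[-1.0, -2.0, -3.0]])
    nth_highest = torch.tensor([[0.0]])

    top_n_indices = (logprobs >= nth_highest).nonzero()  # shape (0, 2): no matches
    _, top_n_ishes = torch.unique_consecutive(top_n_indices[:, 0], return_counts=True)
    assert top_n_ishes.numel() == 0

    # Pre-patch: top_n_ishes.max() raises a RuntimeError on the empty tensor.
    # Post-patch: fall back to k=1 so torch.topk always receives a valid k.
    k = 1 if top_n_ishes.numel() == 0 else top_n_ishes.max()
    top_k = torch.topk(logprobs, k=k, dim=1, sorted=True)
    print(top_k.values)  # tensor([[-1.]])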