Fix top_k when k < 0

This commit is contained in:
Nicolas Patry 2023-08-31 20:51:17 +02:00
parent 7d8e5fb284
commit e74a68ee70

View File

@ -363,9 +363,10 @@ def batch_top_tokens(
# Find the new "fuzzy" top n values
top_n_indices = (logprobs >= nth_highest).nonzero()
_, top_n_ishes = torch.unique_consecutive(top_n_indices[:, 0], return_counts=True)
k = 1 if top_n_ishes.numel() == 0 else top_n_ishes.max()
# Take a new topk for these new max n values
top_k = torch.topk(logprobs, k=top_n_ishes.max(), dim=1, sorted=True)
top_k = torch.topk(logprobs, k=k, dim=1, sorted=True)
top_n_ishes = top_n_ishes.tolist()
top_indices = top_k.indices.tolist()