Fix top_k when k < 0

This commit is contained in:
Nicolas Patry 2023-08-31 20:51:17 +02:00
parent 7d8e5fb284
commit e74a68ee70

View File

@@ -364,8 +364,9 @@ def batch_top_tokens(
     top_n_indices = (logprobs >= nth_highest).nonzero()
     _, top_n_ishes = torch.unique_consecutive(top_n_indices[:, 0], return_counts=True)
+    k = 1 if top_n_ishes.numel() == 0 else top_n_ishes.max()
     # Take a new topk for these new max n values
-    top_k = torch.topk(logprobs, k=top_n_ishes.max(), dim=1, sorted=True)
+    top_k = torch.topk(logprobs, k=k, dim=1, sorted=True)
     top_n_ishes = top_n_ishes.tolist()
     top_indices = top_k.indices.tolist()