Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 20:04:52 +00:00)
Fix top_k when k < 0
commit e74a68ee70 (parent 7d8e5fb284)
@@ -363,9 +363,10 @@ def batch_top_tokens(
     # Find the new "fuzzy" top n values
     top_n_indices = (logprobs >= nth_highest).nonzero()
     _, top_n_ishes = torch.unique_consecutive(top_n_indices[:, 0], return_counts=True)

+    k = 1 if top_n_ishes.numel() == 0 else top_n_ishes.max()
     # Take a new topk for these new max n values
-    top_k = torch.topk(logprobs, k=top_n_ishes.max(), dim=1, sorted=True)
+    top_k = torch.topk(logprobs, k=k, dim=1, sorted=True)

     top_n_ishes = top_n_ishes.tolist()
     top_indices = top_k.indices.tolist()
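For context, a minimal sketch of the failure mode this guard addresses (the inputs below are hypothetical stand-ins, not the exact values batch_top_tokens sees in production): when no logprob clears the nth_highest threshold, top_n_indices is empty, top_n_ishes ends up with zero elements, and calling .max() on it raises a RuntimeError before torch.topk is ever reached. The k = 1 fallback keeps the torch.topk call valid in that case.

import torch

# Hypothetical inputs: a small batch of log-probabilities and a threshold
# that nothing can clear, so the "fuzzy top n" match set is empty.
logprobs = torch.log_softmax(torch.randn(2, 8), dim=-1)
nth_highest = torch.full((2, 1), float("inf"))

top_n_indices = (logprobs >= nth_highest).nonzero()  # shape (0, 2)
_, top_n_ishes = torch.unique_consecutive(top_n_indices[:, 0], return_counts=True)

# Before the fix: top_n_ishes.max() on an empty tensor raises a RuntimeError.
# After the fix: fall back to k=1 so torch.topk always receives a valid k.
k = 1 if top_n_ishes.numel() == 0 else top_n_ishes.max()
top_k = torch.topk(logprobs, k=k, dim=1, sorted=True)
print(top_k.indices.shape)  # torch.Size([2, 1]) in the empty case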