fix(server): avoid errors for very small top_p values

See https://github.com/huggingface/transformers/pull/24453.

I didn't add top_p validation to the __init__ method, since the other values/warpers aren't validated there either.
This commit is contained in:
Nick Hill 2023-07-04 10:59:40 -07:00
parent 2a101207d4
commit 8a7bfcd571

View File

@@ -189,9 +189,8 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept) # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
sorted_indices_to_remove = probs <= self.top_p_opposite sorted_indices_to_remove = probs <= self.top_p_opposite
if self.min_tokens_to_keep > 1: # Keep at least min_tokens_to_keep
# Keep at least min_tokens_to_keep sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
# scatter sorted tensors to original indexing # scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter( indices_to_remove = sorted_indices_to_remove.scatter(