fix(server): avoid errors for very small top_p values

See https://github.com/huggingface/transformers/pull/24453.

I didn't add validation to the __init__ method since it's not done for other values/warpers.
This commit is contained in:
Nick Hill 2023-07-04 10:59:40 -07:00
parent 2a101207d4
commit 8a7bfcd571

View File

@@ -189,7 +189,6 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
sorted_indices_to_remove = probs <= self.top_p_opposite
if self.min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep
sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0