Fix top_n_tokens returning non-log probs for some models

Vincent Brouwers 2023-09-14 08:49:35 +00:00
parent c8a01d7591
commit 2a16b4101f
2 changed files with 2 additions and 2 deletions


@@ -579,7 +579,7 @@ class CausalLM(Model):
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens,
             batch.top_n_tokens_tensor,
-            torch.softmax(logits[:, -1], -1),
+            torch.log_softmax(logits[:, -1], -1),
         )
         # Zipped iterator


@@ -642,7 +642,7 @@ class Seq2SeqLM(Model):
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens,
             batch.top_n_tokens_tensor,
-            torch.softmax(logits[:, -1], -1),
+            torch.log_softmax(logits[:, -1], -1),
         )
         # Finished requests
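
The fix is identical in both models: the tensor passed to batch_top_tokens is now produced by torch.log_softmax instead of torch.softmax, so the top-n values returned for each token are genuine log-probabilities. A minimal standalone sketch of the difference (the logits tensor is made up; only PyTorch is assumed):

import torch

# Fake last-position logits for a single sequence in the batch.
logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])

probs = torch.softmax(logits, -1)          # pre-fix input: probabilities in (0, 1)
log_probs = torch.log_softmax(logits, -1)  # post-fix input: log-probabilities, all <= 0

print(probs)      # tensor([[0.6095, 0.2242, 0.1360, 0.0303]])
print(log_probs)  # tensor([[-0.4952, -1.4952, -1.9952, -3.4952]])

# log_softmax is mathematically log(softmax(x)), but computed in a numerically
# stable fused form, so tiny probabilities do not underflow to -inf.
assert torch.allclose(log_probs, probs.log())

Before this change, batch_top_tokens received the first tensor, so values like 0.6095 were reported as "logprobs" even though a true log-probability can never be positive.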