From 2a16b4101f5c265ba7c759d7d04d8881c07acf48 Mon Sep 17 00:00:00 2001
From: Vincent Brouwers
Date: Thu, 14 Sep 2023 08:49:35 +0000
Subject: [PATCH] Fix top_n_tokens returning non-log probs for some models

---
 server/text_generation_server/models/causal_lm.py  | 2 +-
 server/text_generation_server/models/seq2seq_lm.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 4e338263..e4496ee6 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -579,7 +579,7 @@ class CausalLM(Model):
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens,
             batch.top_n_tokens_tensor,
-            torch.softmax(logits[:, -1], -1),
+            torch.log_softmax(logits[:, -1], -1),
         )
 
         # Zipped iterator
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 361453fb..1194f289 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -642,7 +642,7 @@ class Seq2SeqLM(Model):
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens,
             batch.top_n_tokens_tensor,
-            torch.softmax(logits[:, -1], -1),
+            torch.log_softmax(logits[:, -1], -1),
         )
 
         # Finished requests
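
Note (illustration only, not part of the patch): batch_top_tokens was being fed
plain probabilities from torch.softmax, presumably while its output is reported
as logprobs alongside the chosen-token logprobs, which is why the top_n_tokens
values came back as non-log probs. A minimal sketch of the difference, using a
made-up logits tensor:

    import torch

    logits = torch.tensor([[[2.0, 1.0, 0.5]]])  # dummy [batch, seq, vocab] logits

    probs = torch.softmax(logits[:, -1], -1)         # old call: values in (0, 1)
    logprobs = torch.log_softmax(logits[:, -1], -1)  # fixed call: values <= 0

    print(probs)     # approx. tensor([[0.6285, 0.2312, 0.1402]])
    print(logprobs)  # approx. tensor([[-0.4644, -1.4644, -1.9644]])

    # log_softmax(x) equals log(softmax(x)), but is computed in a numerically stabler way
    assert torch.allclose(logprobs, probs.log())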