diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index cec9ae552..696f0fb23 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -579,7 +579,7 @@ class CausalLM(Model):
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens,
             batch.top_n_tokens_tensor,
-            torch.softmax(logits[:, -1], -1),
+            torch.log_softmax(logits[:, -1], -1),
         )
 
         # Zipped iterator
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 1a7911acd..34932c0b5 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -642,7 +642,7 @@ class Seq2SeqLM(Model):
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens,
             batch.top_n_tokens_tensor,
-            torch.softmax(logits[:, -1], -1),
+            torch.log_softmax(logits[:, -1], -1),
         )
 
         # Finished requests
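
Note on the change: the second return value is named `batch_top_token_logprobs`, which suggests `batch_top_tokens` expects log-probabilities, while the old code passed plain probabilities from `torch.softmax`. A minimal standalone PyTorch sketch (not the server code; the tensor values are made up) illustrating the difference:

```python
import torch

# Hypothetical logits for a batch of 2 sequences over a 4-token vocab;
# in the server these come from logits[:, -1] (the last position).
logits = torch.tensor([[2.0, 1.0, 0.5, -1.0],
                       [0.1, 3.0, -2.0, 0.0]])

probs = torch.softmax(logits, -1)         # probabilities in [0, 1]
logprobs = torch.log_softmax(logits, -1)  # log-probabilities, <= 0

# log_softmax is the numerically stable equivalent of log(softmax(...)):
assert torch.allclose(logprobs, probs.log(), atol=1e-6)

# Passing `probs` where log-probs are expected (the pre-fix behavior)
# would report top-token values in [0, 1] instead of the expected
# non-positive log-probabilities.
print(probs.topk(2, dim=-1).values)     # e.g. tensor([[0.61, 0.22], ...])
print(logprobs.topk(2, dim=-1).values)  # e.g. tensor([[-0.50, -1.50], ...])
```

Using `torch.log_softmax` directly is also preferable to `torch.softmax(...).log()`, since it avoids underflow when probabilities are very small.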