diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 7b10256c..151be1bd 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -719,7 +719,7 @@ class CausalLM(Model): [next_token_id_squeezed.item() in self.all_special_ids], ), generated_text, - top_tokens, + [top_tokens], ) generations.append(generation) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 930082cd..bafe30d9 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1089,7 +1089,7 @@ class FlashCausalLM(Model): [nid in self.all_special_ids for nid in _next_token_ids], ), generated_text, - top_tokens, + [top_tokens], ) generations.append(generation) diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py index 2f28688d..e1972b11 100644 --- a/server/text_generation_server/models/idefics_causal_lm.py +++ b/server/text_generation_server/models/idefics_causal_lm.py @@ -809,7 +809,7 @@ class IdeficsCausalLM(Model): [next_token_id_squeezed.item() in self.all_special_ids], ), generated_text, - top_tokens, + [top_tokens], ) generations.append(generation) diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index f2e4cec6..191c8c8a 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -773,7 +773,7 @@ class Seq2SeqLM(Model): [next_token_id_squeezed.item() in self.all_special_ids], ), generated_text, - top_tokens, + [top_tokens], ) generations.append(generation) diff --git a/server/text_generation_server/models/types.py b/server/text_generation_server/models/types.py index f85f27e5..8f926cfb 100644 --- a/server/text_generation_server/models/types.py +++ b/server/text_generation_server/models/types.py @@ -95,5 +95,7 @@ class Generation: generated_text=self.generated_text.to_pb() if self.generated_text is not None else None, - top_tokens=self.top_tokens.to_pb() if self.top_tokens is not None else None, + top_tokens=([top_token.to_pb() for top_token in self.top_tokens] + if self.top_tokens is not None + else None), )