From 1f7042d16549a2d11c977c3e2476808c0e563625 Mon Sep 17 00:00:00 2001
From: gduhamel
Date: Tue, 23 Jan 2024 21:05:13 +0100
Subject: [PATCH] fix error if top_n_tokens is 0 or null

---
 server/text_generation_server/models/causal_lm.py         | 2 +-
 server/text_generation_server/models/flash_causal_lm.py   | 2 +-
 server/text_generation_server/models/idefics_causal_lm.py | 2 +-
 server/text_generation_server/models/seq2seq_lm.py        | 2 +-
 4 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 151be1bd..aec7c073 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -719,7 +719,7 @@ class CausalLM(Model):
                         [next_token_id_squeezed.item() in self.all_special_ids],
                     ),
                     generated_text,
-                    [top_tokens],
+                    [top_tokens] if top_tokens is not None else None,
                 )
 
                 generations.append(generation)
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index bafe30d9..9993954b 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1089,7 +1089,7 @@ class FlashCausalLM(Model):
                         [nid in self.all_special_ids for nid in _next_token_ids],
                     ),
                     generated_text,
-                    [top_tokens],
+                    [top_tokens] if top_tokens is not None else None,
                 )
 
                 generations.append(generation)
diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py
index e1972b11..f4aeb9da 100644
--- a/server/text_generation_server/models/idefics_causal_lm.py
+++ b/server/text_generation_server/models/idefics_causal_lm.py
@@ -809,7 +809,7 @@ class IdeficsCausalLM(Model):
                         [next_token_id_squeezed.item() in self.all_special_ids],
                     ),
                     generated_text,
-                    [top_tokens],
+                    [top_tokens] if top_tokens is not None else None,
                 )
 
                 generations.append(generation)
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 191c8c8a..37b22609 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -773,7 +773,7 @@ class Seq2SeqLM(Model):
                         [next_token_id_squeezed.item() in self.all_special_ids],
                     ),
                     generated_text,
-                    [top_tokens],
+                    [top_tokens] if top_tokens is not None else None,
                 )
 
                 generations.append(generation)
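
Note (reviewer sketch, not part of the patch): a minimal, self-contained
illustration of the guard the four hunks introduce. The assumption, taken
from the subject line, is that the per-request top_tokens value is None
whenever top_n_tokens is 0 or null, so wrapping it unconditionally
presumably handed a [None] list to downstream consumers of Generation.
The helper name wrap_top_tokens is hypothetical, used here only to isolate
the patched expression:

    from typing import List, Optional

    def wrap_top_tokens(top_tokens: Optional[list]) -> Optional[List[list]]:
        # Mirrors the patched expression: wrap only when a value is present,
        # so callers never receive a list containing a bare None.
        return [top_tokens] if top_tokens is not None else None

    assert wrap_top_tokens(None) is None            # top_n_tokens == 0 or null
    assert wrap_top_tokens(["tok"]) == [["tok"]]    # top_n_tokens > 0

The same one-line expression is applied verbatim in all four model classes
(CausalLM, FlashCausalLM, IdeficsCausalLM, Seq2SeqLM), since each constructs
its Generation objects the same way.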