diff --git a/launcher/tests/bloom_560m.json b/launcher/tests/bloom_560m.json
index f96a8a23..96f89f6b 100644
--- a/launcher/tests/bloom_560m.json
+++ b/launcher/tests/bloom_560m.json
@@ -12,7 +12,7 @@
       },
       {
         "id": 8821,
-        "text": "Ġrequest",
+        "text": " request",
         "logprob": -11.894989
       }
     ],
diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py
index 26b6d78f..d15197d0 100644
--- a/server/text_generation/models/causal_lm.py
+++ b/server/text_generation/models/causal_lm.py
@@ -428,8 +428,9 @@ class CausalLM(Model):
                     1, all_input_ids[1:]
                 ).squeeze(1)[-new_input_length:-1].tolist()
                 prefill_token_ids = all_input_ids[-new_input_length:-1]
-                prefill_texts = self.tokenizer.convert_ids_to_tokens(
+                prefill_texts = self.tokenizer.batch_decode(
                     prefill_token_ids,
+                    clean_up_tokenization_spaces=False,
                     skip_special_tokens=False,
                 )
                 prefill_tokens = PrefillTokens(
diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py
index 1409d338..3a4108ab 100644
--- a/server/text_generation/models/seq2seq_lm.py
+++ b/server/text_generation/models/seq2seq_lm.py
@@ -495,13 +495,10 @@ class Seq2SeqLM(Model):
             # Prefill
             if stopping_criteria.current_tokens == 1:
-                prefill_token_ids = decoder_input_ids[-new_decoder_input_length:-1]
-                prefill_texts = self.tokenizer.convert_ids_to_tokens(
-                    prefill_token_ids,
-                    skip_special_tokens=False,
-                )
                 prefill_tokens = PrefillTokens(
-                    prefill_token_ids, [float("nan")], prefill_texts
+                    [self.tokenizer.bos_token_id],
+                    [float("nan")],
+                    [self.tokenizer.bos_token],
                 )
             else:
                 prefill_tokens = None
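
A quick way to see the effect of this change, as a minimal sketch (assuming `transformers` is installed and that the fixture above targets the `bigscience/bloom-560m` tokenizer; token id 8821 comes straight from `bloom_560m.json`): BLOOM's byte-level BPE stores a leading space as the marker `Ġ` in its vocabulary, so `convert_ids_to_tokens` returns the raw vocabulary string `Ġrequest`, while `batch_decode` with `clean_up_tokenization_spaces=False` maps the id back to the actual text `" request"` without touching whitespace.

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")

    # Raw vocabulary entry: the byte-level BPE marker leaks through
    # ("Ġ" stands for a leading space).
    print(tokenizer.convert_ids_to_tokens([8821]))
    # ['Ġrequest']

    # Lossless decode: restores the real leading space, matching the
    # updated fixture above.
    print(
        tokenizer.batch_decode(
            [8821],
            clean_up_tokenization_spaces=False,
            skip_special_tokens=False,
        )
    )
    # [' request']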