revert to batch_decode

OlivierDehaene 2023-02-24 15:34:20 +01:00
parent ed59f16b96
commit 4698368a1a
3 changed files with 6 additions and 8 deletions


@@ -12,7 +12,7 @@
   },
   {
     "id": 8821,
-    "text": "Ġrequest",
+    "text": " request",
     "logprob": -11.894989
   }
 ],
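The fixture change above follows directly from swapping convert_ids_to_tokens for batch_decode: the former returns the raw byte-level BPE vocabulary entry ("Ġrequest"), while the latter decodes the id back to plain text (" request"). A minimal sketch of the difference, assuming a GPT-2 style tokenizer as a stand-in (the actual model behind this fixture is not shown in the diff):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed stand-in model

    ids = tokenizer.encode(" request")

    # convert_ids_to_tokens returns raw vocabulary strings, where a leading
    # space is stored as the byte-level marker "Ġ".
    print(tokenizer.convert_ids_to_tokens(ids))  # e.g. ['Ġrequest']

    # batch_decode decodes each id back to text, restoring the real space.
    print(tokenizer.batch_decode(ids))           # e.g. [' request']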


@@ -428,8 +428,9 @@ class CausalLM(Model):
                     1, all_input_ids[1:]
                 ).squeeze(1)[-new_input_length:-1].tolist()
                 prefill_token_ids = all_input_ids[-new_input_length:-1]
-                prefill_texts = self.tokenizer.convert_ids_to_tokens(
+                prefill_texts = self.tokenizer.batch_decode(
                     prefill_token_ids,
+                    clean_up_tokenization_spaces=False,
                     skip_special_tokens=False,
                 )
                 prefill_tokens = PrefillTokens(
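Note the added clean_up_tokenization_spaces=False: without it, decoding applies cleanup rules such as stripping the space before punctuation, so the per-token texts would no longer concatenate back to the original prompt. A small sketch, again assuming a GPT-2 style tokenizer as a stand-in:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed stand-in model

    ids = tokenizer.encode("Hello , world")  # " ," is a single token here

    print(tokenizer.batch_decode(ids, clean_up_tokenization_spaces=True))
    # e.g. ['Hello', ',', ' world']   cleanup drops the space before ","

    print(tokenizer.batch_decode(ids, clean_up_tokenization_spaces=False))
    # e.g. ['Hello', ' ,', ' world']  per-token texts still concatenate exactly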


@@ -495,13 +495,10 @@ class Seq2SeqLM(Model):
             # Prefill
             if stopping_criteria.current_tokens == 1:
-                prefill_token_ids = decoder_input_ids[-new_decoder_input_length:-1]
-                prefill_texts = self.tokenizer.convert_ids_to_tokens(
-                    prefill_token_ids,
-                    skip_special_tokens=False,
-                )
                 prefill_tokens = PrefillTokens(
-                    prefill_token_ids, [float("nan")], prefill_texts
+                    [self.tokenizer.bos_token_id],
+                    [float("nan")],
+                    [self.tokenizer.bos_token],
                 )
             else:
                 prefill_tokens = None
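In the Seq2SeqLM branch the slice-and-decode path is dropped entirely rather than switched to batch_decode: at the first decode step the decoder has only seen its start token, so the prefill is always exactly one BOS token whose id and text can be read off the tokenizer directly. A sketch of why the removed slice was redundant (the tensor values here are hypothetical):

    import torch

    bos_token_id = 0  # hypothetical BOS id
    # After the first generated token, the decoder inputs are [BOS, new_token].
    decoder_input_ids = torch.tensor([bos_token_id, 42])
    new_decoder_input_length = decoder_input_ids.shape[0]

    # The slice computed by the removed code...
    prefill_token_ids = decoder_input_ids[-new_decoder_input_length:-1]

    # ...always reduces to the single start token at this point.
    assert prefill_token_ids.tolist() == [bos_token_id]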