Remove the stripping of the prefix space (and any other mangling that tokenizers might do).

Supersedes #1024

Co-Authored-By: bangoz <ch_xie@pku.edu.cn>
Nicolas Patry 2023-09-26 14:12:25 +00:00
parent ae623b8d2d
commit 76d5bbb0aa
4 changed files with 16 additions and 13 deletions
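
For context, the replacement relies on incremental detokenization: decode a window that starts one token before the already-read region, decode it again with the new tokens appended, and emit only the textual difference, so whatever prefix-space stripping or cleanup the tokenizer applies cancels out between the two decodes. A minimal sketch of that idea (illustrative only, not the verbatim decode_token body):

def decode_token(self, all_input_ids, prefix_offset=0, read_offset=0):
    # Decode the already-read tail, starting one token early so the
    # tokenizer sees left context and mangles both decodes identically.
    prefix_text = self.tokenizer.decode(
        all_input_ids[prefix_offset:read_offset], skip_special_tokens=False
    )
    new_text = self.tokenizer.decode(
        all_input_ids[prefix_offset:], skip_special_tokens=False
    )
    if len(new_text) > len(prefix_text) and not new_text.endswith("\ufffd"):
        # The shared prefix text (and any space the tokenizer injected or
        # stripped at its start) cancels out; a trailing U+FFFD would mean
        # an incomplete byte sequence, so keep waiting for more tokens.
        return new_text[len(prefix_text):], read_offset, len(all_input_ids)
    return "", prefix_offset, read_offset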


@@ -641,8 +641,10 @@ class CausalLM(Model):
             if i % self.world_size == self.rank:
                 if stop:
                     # Decode generated tokens
-                    output_text = self.decode(
-                        all_input_ids[-stopping_criteria.current_tokens :, 0]
+                    output_text, _, _ = self.decode_token(
+                        all_input_ids[:, 0],
+                        prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
+                        read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
                     )
                     # Get seed
                     if isinstance(next_token_chooser.choice, Sampling):
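
The offset arithmetic, with hypothetical numbers: for a sequence of 10 ids where the last 4 were generated,

all_input_ids_len = 10                                  # prompt (6) + generated (4)
current_tokens = 4
prefix_offset = all_input_ids_len - current_tokens - 1  # 5: one token of left context
read_offset = all_input_ids_len - current_tokens        # 6: start of the generated ids

Passing the full id sequence plus these offsets lets decode_token subtract the shared prefix text, instead of slicing out all_input_ids[-current_tokens:] and decoding the generated ids with no left context, which is what stripped the prefix space. The same offsets are used in each file below.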


@@ -1008,8 +1008,10 @@ class FlashCausalLM(Model):
             if i % self.world_size == self.rank:
                 if stop:
                     # Decode generated tokens
-                    output_text = self.decode(
-                        all_input_ids[-stopping_criteria.current_tokens :]
+                    output_text, _, _ = self.decode_token(
+                        all_input_ids,
+                        prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
+                        read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
                     )
                     generated_text = GeneratedText(
                         output_text,


@@ -611,11 +611,6 @@ class IdeficsCausalLM(Model):
     def batch_type(self) -> Type[IdeficsCausalLMBatch]:
         return IdeficsCausalLMBatch
 
-    def decode(self, generated_ids: List[int]) -> str:
-        return self.tokenizer.decode(
-            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
-        )
-
     def forward(
         self,
         input_ids,
@@ -728,8 +723,10 @@ class IdeficsCausalLM(Model):
             if i % self.world_size == self.rank:
                 if stop:
                     # Decode generated tokens
-                    output_text = self.decode(
-                        all_input_ids[-stopping_criteria.current_tokens :, 0]
+                    output_text, _, _ = self.decode_token(
+                        all_input_ids[:, 0],
+                        prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
+                        read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
                     )
                     # Get seed
                     if isinstance(next_token_chooser.choice, Sampling):


@@ -710,8 +710,10 @@ class Seq2SeqLM(Model):
                 if stop:
                     # Slice with decoder_input_length to remove padding
                     # Decode all tokens
-                    output_text = self.decode(
-                        all_decoder_input_ids[-decoder_input_length:]
+                    output_text, _, _ = self.decode_token(
+                        all_decoder_input_ids,
+                        prefix_offset=len(all_decoder_input_ids) - decoder_input_length - 1,
+                        read_offset=len(all_decoder_input_ids) - decoder_input_length,
                     )
                     # Get seed