Ignoring special tokens + updating 1 test case.

Nicolas Patry 2023-09-26 16:35:53 +02:00
parent 76d5bbb0aa
commit 8a16c48595
6 changed files with 9 additions and 9 deletions
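In tokenizer terms, the change means special markers such as </s> or <pad> no longer leak into decoded text. A minimal sketch of the difference, assuming the Hugging Face transformers library and using t5-small purely as an example model (not part of this commit):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")  # example model, for illustration only

ids = tokenizer("a few weeks").input_ids                 # T5-style tokenizers append the </s> special token
print(tokenizer.decode(ids))                             # -> "a few weeks</s>"  (special token leaks into the text)
print(tokenizer.decode(ids, skip_special_tokens=True))   # -> "a few weeks"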

View File

@@ -168,7 +168,7 @@ def test_seq2seq_lm_generate_token_completion(
assert next_batch is None
assert len(generations) == 1
-assert generations[0].generated_text.text == "a few weeks"
+assert generations[0].generated_text.text == " a few weeks"
assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
assert generations[0].generated_text.generated_tokens == 7
@@ -186,7 +186,7 @@ def test_seq2seq_lm_generate_token_completion_multi(
assert next_batch is not None
assert len(generations) == 2
-assert generations[1].generated_text.text == "a few "
+assert generations[1].generated_text.text == " a few "
assert (
generations[1].request_id
== default_multi_requests_seq2seq_lm_batch.requests[1].id

View File

@@ -645,6 +645,7 @@ class CausalLM(Model):
all_input_ids[:, 0],
prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
+skip_special_tokens=True
)
# Get seed
if isinstance(next_token_chooser.choice, Sampling):

View File

@@ -793,11 +793,6 @@ class FlashCausalLM(Model):
return int(num_blocks * BLOCK_SIZE)
-def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
-return self.tokenizer.decode(
-generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
-)
def forward(
self,
input_ids: torch.Tensor,
@@ -1012,6 +1007,7 @@ class FlashCausalLM(Model):
all_input_ids,
prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
+skip_special_tokens=True
)
generated_text = GeneratedText(
output_text,

View File

@@ -727,6 +727,7 @@ class IdeficsCausalLM(Model):
all_input_ids[:, 0],
prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
+skip_special_tokens=True
)
# Get seed
if isinstance(next_token_chooser.choice, Sampling):

View File

@@ -64,16 +64,17 @@ class Model(ABC):
all_input_ids: List[int],
prefix_offset: int = 0,
read_offset: int = 0,
+skip_special_tokens: bool = False,
) -> Tuple[str, int, int]:
"""Hack to hopefully support generate_stream for the maximum number of tokenizers"""
# The prefix text is necessary only to defeat cleanup algorithms in the decode
# which decide to add a space or not depending on the surrounding ids.
prefix_text = self.tokenizer.decode(
-all_input_ids[prefix_offset:read_offset], skip_special_tokens=False
+all_input_ids[prefix_offset:read_offset], skip_special_tokens=skip_special_tokens
)
new_text = self.tokenizer.decode(
-all_input_ids[prefix_offset:], skip_special_tokens=False
+all_input_ids[prefix_offset:], skip_special_tokens=skip_special_tokens
)
if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
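For reference, a minimal, self-contained sketch of the incremental decoding trick used by decode_token, assuming the Hugging Face transformers gpt2 tokenizer; the names mirror the method above, but this is an illustration rather than the server's actual streaming path:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # example tokenizer, for illustration only

def decode_token(all_input_ids, prefix_offset=0, read_offset=0, skip_special_tokens=False):
    # Re-decode a short prefix together with the tail, then emit only the text
    # that appeared beyond the prefix; re-decoding the prefix defeats cleanup
    # heuristics that add or drop a space depending on the surrounding ids.
    prefix_text = tokenizer.decode(
        all_input_ids[prefix_offset:read_offset], skip_special_tokens=skip_special_tokens
    )
    new_text = tokenizer.decode(
        all_input_ids[prefix_offset:], skip_special_tokens=skip_special_tokens
    )
    if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
        return new_text[len(prefix_text):], read_offset, len(all_input_ids)
    return "", prefix_offset, read_offset

# Streaming usage: feed the growing id list one token at a time.
ids = tokenizer("Hello world, how are you?").input_ids
text, prefix_offset, read_offset = "", 0, 0
for i in range(1, len(ids) + 1):
    piece, prefix_offset, read_offset = decode_token(
        ids[:i], prefix_offset, read_offset, skip_special_tokens=True
    )
    text += piece
print(text)  # reconstructs "Hello world, how are you?" with special tokens skipped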

View File

@@ -714,6 +714,7 @@ class Seq2SeqLM(Model):
all_decoder_input_ids,
prefix_offset=len(all_decoder_input_ids) - decoder_input_length - 1,
read_offset=len(all_decoder_input_ids) - decoder_input_length,
+skip_special_tokens=True
)
# Get seed