mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)
Ignoring special tokens + updating 1 test case.
commit 8a16c48595 · parent 76d5bbb0aa
@@ -168,7 +168,7 @@ def test_seq2seq_lm_generate_token_completion(
     assert next_batch is None
 
     assert len(generations) == 1
-    assert generations[0].generated_text.text == "a few weeks"
+    assert generations[0].generated_text.text == " a few weeks"
     assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
     assert generations[0].generated_text.generated_tokens == 7
 
@@ -186,7 +186,7 @@ def test_seq2seq_lm_generate_token_completion_multi(
     assert next_batch is not None
 
     assert len(generations) == 2
-    assert generations[1].generated_text.text == "a few "
+    assert generations[1].generated_text.text == " a few "
     assert (
         generations[1].request_id
         == default_multi_requests_seq2seq_lm_batch.requests[1].id
@@ -645,6 +645,7 @@ class CausalLM(Model):
                     all_input_ids[:, 0],
                     prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
                     read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
+                    skip_special_tokens=True
                 )
                 # Get seed
                 if isinstance(next_token_chooser.choice, Sampling):
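A small sketch (not part of the diff) of what the offsets in this call select: prefix_offset points at the last prompt token and read_offset at the first generated token, so decode_token can return exactly the generated text while the extra prompt token only stabilizes spacing at the boundary. The ids and counts below are made up for illustration.

# Illustrative ids only: 3 "prompt" tokens followed by 4 "generated" tokens.
prompt_ids = [101, 102, 103]
generated_ids = [7, 8, 9, 10]
all_input_ids = prompt_ids + generated_ids

# stopping_criteria.current_tokens counts generated tokens at this point.
current_tokens = len(generated_ids)

prefix_offset = len(all_input_ids) - current_tokens - 1  # index of the last prompt token
read_offset = len(all_input_ids) - current_tokens        # index of the first generated token

# The prefix window holds one prompt token; the tail holds all generated tokens,
# so the text delta returned by decode_token is the generated text only.
assert all_input_ids[prefix_offset:read_offset] == [103]
assert all_input_ids[read_offset:] == [7, 8, 9, 10]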
@@ -793,11 +793,6 @@ class FlashCausalLM(Model):
 
         return int(num_blocks * BLOCK_SIZE)
 
-    def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
-        return self.tokenizer.decode(
-            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
-        )
-
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -1012,6 +1007,7 @@ class FlashCausalLM(Model):
                     all_input_ids,
                     prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
                     read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
+                    skip_special_tokens=True
                 )
                 generated_text = GeneratedText(
                     output_text,
@@ -727,6 +727,7 @@ class IdeficsCausalLM(Model):
                     all_input_ids[:, 0],
                     prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
                     read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
+                    skip_special_tokens=True
                 )
                 # Get seed
                 if isinstance(next_token_chooser.choice, Sampling):
@@ -64,16 +64,17 @@ class Model(ABC):
         all_input_ids: List[int],
         prefix_offset: int = 0,
         read_offset: int = 0,
+        skip_special_tokens: bool = False,
     ) -> Tuple[str, int, int]:
         """Hack to hopefully support generate_stream for the maximum number of tokenizers"""
 
         # The prefix text is necessary only to defeat cleanup algorithms in the decode
         # which decide to add a space or not depending on the surrounding ids.
         prefix_text = self.tokenizer.decode(
-            all_input_ids[prefix_offset:read_offset], skip_special_tokens=False
+            all_input_ids[prefix_offset:read_offset], skip_special_tokens=skip_special_tokens
         )
         new_text = self.tokenizer.decode(
-            all_input_ids[prefix_offset:], skip_special_tokens=False
+            all_input_ids[prefix_offset:], skip_special_tokens=skip_special_tokens
         )
 
         if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
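For readers unfamiliar with this helper, below is a self-contained sketch of the incremental decoding trick that decode_token implements, with the new skip_special_tokens flag threaded through to both tokenizer.decode calls. Only the lines visible in this hunk come from the source; the return handling after the length check, the toy streaming loop, and the gpt2 tokenizer choice are assumptions made for illustration.

from typing import List, Tuple

from transformers import AutoTokenizer


def decode_token(
    tokenizer,
    all_input_ids: List[int],
    prefix_offset: int = 0,
    read_offset: int = 0,
    skip_special_tokens: bool = False,
) -> Tuple[str, int, int]:
    # Decode a short prefix window and the full tail; the prefix text exists only
    # to cancel the tokenizer's space-cleanup heuristics at the window boundary.
    prefix_text = tokenizer.decode(
        all_input_ids[prefix_offset:read_offset], skip_special_tokens=skip_special_tokens
    )
    new_text = tokenizer.decode(
        all_input_ids[prefix_offset:], skip_special_tokens=skip_special_tokens
    )
    if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
        # A complete chunk of new text is available: return only the delta and
        # advance both offsets (assumed behavior, not shown in this hunk).
        return new_text[len(prefix_text):], read_offset, len(all_input_ids)
    # Otherwise wait for more ids (e.g. an unfinished multi-byte character).
    return "", prefix_offset, read_offset


# Toy streaming loop: feed the ids one at a time and accumulate the deltas.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
ids = tokenizer("streaming decode works token by token").input_ids
text, prefix_offset, read_offset = "", 0, 0
for i in range(1, len(ids) + 1):
    delta, prefix_offset, read_offset = decode_token(
        tokenizer, ids[:i], prefix_offset, read_offset, skip_special_tokens=True
    )
    text += delta
print(text)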
@@ -714,6 +714,7 @@ class Seq2SeqLM(Model):
                     all_decoder_input_ids,
                     prefix_offset=len(all_decoder_input_ids) - decoder_input_length - 1,
                     read_offset=len(all_decoder_input_ids) - decoder_input_length,
+                    skip_special_tokens=True
                 )
 
                 # Get seed