mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 20:04:52 +00:00)

Ignoring special tokens + updating 1 test case.
commit 8a16c48595
parent 76d5bbb0aa
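
The commit threads a skip_special_tokens flag through the incremental decoding path so that special tokens (EOS, pad, decoder-start markers) no longer leak into generated text. As a minimal standalone illustration of what that flag controls in a Hugging Face tokenizer (the "t5-small" checkpoint is just an example choice, not taken from this diff):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")

# T5 appends the </s> special token when encoding.
ids = tokenizer("a few weeks").input_ids

print(tokenizer.decode(ids, skip_special_tokens=False))  # "a few weeks</s>"
print(tokenizer.decode(ids, skip_special_tokens=True))   # "a few weeks"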
@@ -168,7 +168,7 @@ def test_seq2seq_lm_generate_token_completion(
     assert next_batch is None

     assert len(generations) == 1
-    assert generations[0].generated_text.text == "a few weeks"
+    assert generations[0].generated_text.text == " a few weeks"
     assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
     assert generations[0].generated_text.generated_tokens == 7

@@ -186,7 +186,7 @@ def test_seq2seq_lm_generate_token_completion_multi(
     assert next_batch is not None

     assert len(generations) == 2
-    assert generations[1].generated_text.text == "a few "
+    assert generations[1].generated_text.text == " a few "
     assert (
         generations[1].request_id
         == default_multi_requests_seq2seq_lm_batch.requests[1].id

@@ -645,6 +645,7 @@ class CausalLM(Model):
                     all_input_ids[:, 0],
                     prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
                     read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
+                    skip_special_tokens=True
                 )
                 # Get seed
                 if isinstance(next_token_chooser.choice, Sampling):
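
A worked example of the offset arithmetic in the call above, assuming stopping_criteria.current_tokens counts the tokens generated so far (the concrete numbers are made up for illustration):

# Hypothetical sizes, for illustration only.
total_ids = 10       # prompt + generated ids held in all_input_ids
current_tokens = 7   # tokens generated so far

prefix_offset = total_ids - current_tokens - 1  # 2 -> index of the last prompt token
read_offset = total_ids - current_tokens        # 3 -> index of the first generated token

# The decode helper (see the Model hunk further down) decodes ids[2:3] as the
# prefix window and ids[2:] as the full window; the difference between the two
# decoded strings is exactly the generated text, with the last prompt token
# included only to stabilize the tokenizer's space-handling heuristics.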
|
@@ -793,11 +793,6 @@ class FlashCausalLM(Model):

         return int(num_blocks * BLOCK_SIZE)

-    def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
-        return self.tokenizer.decode(
-            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
-        )
-
     def forward(
         self,
         input_ids: torch.Tensor,
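
The removed override hard-coded skip_special_tokens=True for decoding a whole sequence at once; shown below as a standalone sketch (the free-function form and the name decode_full are mine, not the repository's). With the flag now passed at the decode call sites, as in the next hunk, a dedicated override appears to be redundant.

from typing import List, Union

import torch


def decode_full(tokenizer, generated_ids: Union[torch.Tensor, List[int]]) -> str:
    # Decode the whole sequence, drop special tokens, and leave the tokenizer's
    # space cleanup off, mirroring the removed FlashCausalLM.decode body.
    return tokenizer.decode(
        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
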
@@ -1012,6 +1007,7 @@ class FlashCausalLM(Model):
                 all_input_ids,
                 prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
                 read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
+                skip_special_tokens=True
             )
             generated_text = GeneratedText(
                 output_text,

@@ -727,6 +727,7 @@ class IdeficsCausalLM(Model):
                     all_input_ids[:, 0],
                     prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
                     read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
+                    skip_special_tokens=True
                 )
                 # Get seed
                 if isinstance(next_token_chooser.choice, Sampling):

@@ -64,16 +64,17 @@ class Model(ABC):
         all_input_ids: List[int],
         prefix_offset: int = 0,
         read_offset: int = 0,
+        skip_special_tokens: bool = False,
     ) -> Tuple[str, int, int]:
         """Hack to hopefully support generate_stream for the maximum number of tokenizers"""

         # The prefix text is necessary only to defeat cleanup algorithms in the decode
         # which decide to add a space or not depending on the surrounding ids.
         prefix_text = self.tokenizer.decode(
-            all_input_ids[prefix_offset:read_offset], skip_special_tokens=False
+            all_input_ids[prefix_offset:read_offset], skip_special_tokens=skip_special_tokens
         )
         new_text = self.tokenizer.decode(
-            all_input_ids[prefix_offset:], skip_special_tokens=False
+            all_input_ids[prefix_offset:], skip_special_tokens=skip_special_tokens
         )

         if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
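
The hunk above shows only the head of the helper; below is a self-contained sketch of the full incremental-decode trick, assuming the body goes on to return the suffix of new_text past prefix_text (names such as stream_decode are illustrative, not the repository's API):

from typing import List, Tuple


def stream_decode(
    tokenizer,
    all_input_ids: List[int],
    prefix_offset: int = 0,
    read_offset: int = 0,
    skip_special_tokens: bool = False,
) -> Tuple[str, int, int]:
    # Decode a short prefix window purely to defeat the tokenizer's cleanup
    # heuristics, which add or drop spaces depending on the surrounding ids.
    prefix_text = tokenizer.decode(
        all_input_ids[prefix_offset:read_offset], skip_special_tokens=skip_special_tokens
    )
    new_text = tokenizer.decode(
        all_input_ids[prefix_offset:], skip_special_tokens=skip_special_tokens
    )
    if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
        # A complete new chunk was decoded: emit it and slide both offsets forward.
        return new_text[len(prefix_text):], read_offset, len(all_input_ids)
    # Otherwise a multi-byte character is still incomplete: emit nothing yet.
    return "", prefix_offset, read_offset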
|
@@ -714,6 +714,7 @@ class Seq2SeqLM(Model):
                 all_decoder_input_ids,
                 prefix_offset=len(all_decoder_input_ids) - decoder_input_length - 1,
                 read_offset=len(all_decoder_input_ids) - decoder_input_length,
+                skip_special_tokens=True
             )

             # Get seed
|