Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 20:04:52 +00:00)
Remove the stripping of the prefix space (and any other mangling that
tokenizers might do). Supersedes #1024. Co-Authored-By: bangoz <ch_xie@pku.edu.cn>
commit 76d5bbb0aa
parent ae623b8d2d
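Background for the change: the old per-model decode() calls sliced out only the generated token ids and ran tokenizer.decode on that slice, which lets the tokenizer strip a leading space (or apply other cleanup) on the first generated token. The diff below switches every model to the shared decode_token path, which decodes starting one token before the generated region and takes a string diff, so any mangling of the shared prefix cancels out. For context, here is a minimal sketch of that offset-based idea; it is a simplified stand-in, not the repo's exact helper, and the "gpt2" checkpoint in the usage is only an illustrative assumption:

# Minimal sketch of the offset-based incremental decoding idea behind
# decode_token (simplified; not the repo's exact implementation).
from typing import List, Tuple

from transformers import AutoTokenizer, PreTrainedTokenizerBase


def decode_token(
    tokenizer: PreTrainedTokenizerBase,
    all_input_ids: List[int],
    prefix_offset: int,
    read_offset: int,
) -> Tuple[str, int, int]:
    # Decode the already-emitted window (one extra token of context) and the
    # window that also contains the new tokens, then take the string diff.
    # Both decodes share the same leading token, so any prefix-space stripping
    # or cleanup the tokenizer applies to that prefix cancels out in the diff.
    prefix_text = tokenizer.decode(
        all_input_ids[prefix_offset:read_offset], skip_special_tokens=False
    )
    new_text = tokenizer.decode(all_input_ids[prefix_offset:], skip_special_tokens=False)
    if len(new_text) > len(prefix_text) and not new_text.endswith("\ufffd"):
        # Emit only the part not already covered by the prefix window.
        return new_text[len(prefix_text):], read_offset, len(all_input_ids)
    # Trailing replacement char means an incomplete byte sequence: wait for more tokens.
    return "", prefix_offset, read_offset


if __name__ == "__main__":
    # Illustrative usage only; "gpt2" is a placeholder checkpoint.
    tok = AutoTokenizer.from_pretrained("gpt2")
    ids = tok.encode("Hello world, how are you")
    n_generated = 2  # pretend the last two tokens were generated
    text, _, _ = decode_token(
        tok,
        ids,
        prefix_offset=len(ids) - n_generated - 1,
        read_offset=len(ids) - n_generated,
    )
    print(repr(text))  # text for the generated tokens, leading space preserved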
@@ -641,8 +641,10 @@ class CausalLM(Model):
             if i % self.world_size == self.rank:
                 if stop:
                     # Decode generated tokens
-                    output_text = self.decode(
-                        all_input_ids[-stopping_criteria.current_tokens :, 0]
+                    output_text, _, _ = self.decode_token(
+                        all_input_ids[:, 0],
+                        prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
+                        read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
                     )
                     # Get seed
                     if isinstance(next_token_chooser.choice, Sampling):
@@ -1008,8 +1008,10 @@ class FlashCausalLM(Model):
             if i % self.world_size == self.rank:
                 if stop:
                     # Decode generated tokens
-                    output_text = self.decode(
-                        all_input_ids[-stopping_criteria.current_tokens :]
+                    output_text, _, _ = self.decode_token(
+                        all_input_ids,
+                        prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
+                        read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
                     )
                     generated_text = GeneratedText(
                         output_text,
@@ -611,11 +611,6 @@ class IdeficsCausalLM(Model):
     def batch_type(self) -> Type[IdeficsCausalLMBatch]:
         return IdeficsCausalLMBatch
 
-    def decode(self, generated_ids: List[int]) -> str:
-        return self.tokenizer.decode(
-            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
-        )
-
     def forward(
         self,
         input_ids,
@@ -728,8 +723,10 @@ class IdeficsCausalLM(Model):
             if i % self.world_size == self.rank:
                 if stop:
                     # Decode generated tokens
-                    output_text = self.decode(
-                        all_input_ids[-stopping_criteria.current_tokens :, 0]
+                    output_text, _, _ = self.decode_token(
+                        all_input_ids[:, 0],
+                        prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
+                        read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
                     )
                     # Get seed
                     if isinstance(next_token_chooser.choice, Sampling):
@@ -710,8 +710,10 @@ class Seq2SeqLM(Model):
                 if stop:
                     # Slice with decoder_input_length to remove padding
                     # Decode all tokens
-                    output_text = self.decode(
-                        all_decoder_input_ids[-decoder_input_length:]
+                    output_text, _, _ = self.decode_token(
+                        all_decoder_input_ids,
+                        prefix_offset=len(all_decoder_input_ids) - decoder_input_length - 1,
+                        read_offset=len(all_decoder_input_ids) - decoder_input_length,
                     )
 
                     # Get seed