Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 04:14:52 +00:00
Remove the stripping of the prefix space (and any other mangling that tokenizers might do). Supersedes #1024

Co-Authored-By: bangoz <ch_xie@pku.edu.cn>
This commit is contained in:
parent ae623b8d2d
commit 76d5bbb0aa
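The hunks below replace the per-model decode() helpers with the shared decode_token() helper, which decodes incrementally against a one-token prefix so the tokenizer never sees the generated text in isolation. As a rough sketch of that idea, not the exact helper in this repository (the function name, signature, and return shape here are assumptions):

# Sketch of offset-based incremental decoding: decode a window that starts one
# token before the generated text, then return only the part that grew past the
# previously read prefix, so no leading space (or other cleanup) is stripped.
from typing import List, Tuple


def decode_token_sketch(
    tokenizer,                  # assumed: a Hugging Face tokenizer exposing .decode()
    all_input_ids: List[int],   # prompt token ids followed by generated token ids
    prefix_offset: int,         # index one token before the generated text
    read_offset: int,           # index where the text already returned to the caller ends
) -> Tuple[str, int, int]:
    prefix_text = tokenizer.decode(
        all_input_ids[prefix_offset:read_offset], skip_special_tokens=False
    )
    new_text = tokenizer.decode(all_input_ids[prefix_offset:], skip_special_tokens=False)
    if len(new_text) > len(prefix_text) and "\ufffd" not in new_text:
        # The overlapping prefix anchors whitespace; only the newly grown text is returned.
        return new_text[len(prefix_text):], read_offset, len(all_input_ids)
    # Incomplete byte sequence (e.g. byte-fallback tokens): emit nothing yet.
    return "", prefix_offset, read_offset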
@@ -641,8 +641,10 @@ class CausalLM(Model):
             if i % self.world_size == self.rank:
                 if stop:
                     # Decode generated tokens
-                    output_text = self.decode(
-                        all_input_ids[-stopping_criteria.current_tokens :, 0]
+                    output_text, _, _ = self.decode_token(
+                        all_input_ids[:, 0],
+                        prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
+                        read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
                     )
                     # Get seed
                     if isinstance(next_token_chooser.choice, Sampling):
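For concreteness, a small worked example of how the two offsets passed above relate to stopping_criteria.current_tokens; the token ids and counts are assumed values, not taken from this commit:

# Assumed values: a 5-token prompt followed by 3 generated tokens.
all_input_ids = [101, 2023, 2003, 1037, 3231, 7592, 2088, 999]  # stand-in ids, len == 8
current_tokens = 3  # stopping_criteria.current_tokens

prefix_offset = len(all_input_ids) - current_tokens - 1  # 4: one token before the generation
read_offset = len(all_input_ids) - current_tokens        # 5: first generated token

# decode_token() decodes ids[prefix_offset:read_offset] and ids[prefix_offset:],
# then returns their difference, so the token at index 4 anchors the spacing.
print(prefix_offset, read_offset)  # 4 5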
@@ -1008,8 +1008,10 @@ class FlashCausalLM(Model):
             if i % self.world_size == self.rank:
                 if stop:
                     # Decode generated tokens
-                    output_text = self.decode(
-                        all_input_ids[-stopping_criteria.current_tokens :]
+                    output_text, _, _ = self.decode_token(
+                        all_input_ids,
+                        prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
+                        read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
                     )
                     generated_text = GeneratedText(
                         output_text,
@@ -611,11 +611,6 @@ class IdeficsCausalLM(Model):
     def batch_type(self) -> Type[IdeficsCausalLMBatch]:
         return IdeficsCausalLMBatch
 
-    def decode(self, generated_ids: List[int]) -> str:
-        return self.tokenizer.decode(
-            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
-        )
-
     def forward(
         self,
         input_ids,
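The decode() helper removed above passed only the generated slice to tokenizer.decode(), which is where the prefix space got stripped for SentencePiece-style tokenizers. A hypothetical illustration of the symptom (the tokenizer checkpoint is an assumption and exact behaviour varies by tokenizer):

# Hypothetical repro of prefix-space stripping when a suffix is decoded in isolation.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")  # assumed checkpoint

ids = tok("hello world", add_special_tokens=False).input_ids
suffix = ids[-1:]  # pretend the last token ("world") was just generated

# Decoding only the suffix typically yields 'world' with its leading space dropped,
# so naive concatenation gives 'helloworld' instead of 'hello world'.
print(repr(tok.decode(suffix, skip_special_tokens=True)))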
@@ -728,8 +723,10 @@ class IdeficsCausalLM(Model):
             if i % self.world_size == self.rank:
                 if stop:
                     # Decode generated tokens
-                    output_text = self.decode(
-                        all_input_ids[-stopping_criteria.current_tokens :, 0]
+                    output_text, _, _ = self.decode_token(
+                        all_input_ids[:, 0],
+                        prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
+                        read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
                     )
                     # Get seed
                     if isinstance(next_token_chooser.choice, Sampling):
@@ -710,8 +710,10 @@ class Seq2SeqLM(Model):
                 if stop:
                     # Slice with decoder_input_length to remove padding
                     # Decode all tokens
-                    output_text = self.decode(
-                        all_decoder_input_ids[-decoder_input_length:]
+                    output_text, _, _ = self.decode_token(
+                        all_decoder_input_ids,
+                        prefix_offset=len(all_decoder_input_ids) - decoder_input_length - 1,
+                        read_offset=len(all_decoder_input_ids) - decoder_input_length,
                     )
 
                     # Get seed