Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-10 20:04:52 +00:00
fix indent issue
fix tokenizer issue; add 'decode_generated_tokens()' function
This commit is contained in:
parent c8a01d7591
commit 00359fcdc5
@@ -641,8 +641,8 @@ class CausalLM(Model):
             if i % self.world_size == self.rank:
                 if stop:
                     # Decode generated tokens
-                    output_text = self.decode(
-                        all_input_ids[-stopping_criteria.current_tokens :, 0]
+                    output_text = self.decode_generated_tokens(
+                        all_input_ids[:, 0], stopping_criteria.current_tokens
                     )
                     # Get seed
                     if isinstance(next_token_chooser.choice, Sampling):
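The call-site change is worth spelling out: the suffix slice moves out of the caller and into the new helper. A toy sketch of the two argument shapes, using hypothetical ids (not from the commit); in CausalLM, all_input_ids is a 2-D tensor of shape (seq_len, 1):

    import torch

    all_input_ids = torch.tensor([[1], [2], [3], [4], [5]])  # hypothetical ids, shape (5, 1)
    current_tokens = 2  # stands in for stopping_criteria.current_tokens

    # Before: only the generated suffix reached the decoder.
    suffix = all_input_ids[-current_tokens:, 0]  # tensor([4, 5])

    # After: the full column plus the count, so the helper can reach back one
    # token further than the suffix when it builds its prefix text.
    full_ids, n = all_input_ids[:, 0], current_tokens  # tensor([1, 2, 3, 4, 5]), 2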
@@ -1008,9 +1008,11 @@ class FlashCausalLM(Model):
             if i % self.world_size == self.rank:
                 if stop:
                     # Decode generated tokens
-                    output_text = self.decode(
-                        all_input_ids[-stopping_criteria.current_tokens :]
+                    output_text = self.decode_generated_tokens(
+                        all_input_ids,
+                        stopping_criteria.current_tokens,
                     )

                     generated_text = GeneratedText(
                         output_text,
                         stopping_criteria.current_tokens,
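FlashCausalLM stores all_input_ids as a flat List[int], so the whole list is passed together with the count. A toy illustration (not from the commit) of the slice arithmetic the helper will apply to it, assuming num_tokens = 2:

    all_input_ids = [10, 11, 12, 13, 14]  # hypothetical token ids
    num_tokens = 2

    prefix_ids = all_input_ids[-num_tokens - 1 : -num_tokens]  # [12], the one-token prefix
    window_ids = all_input_ids[-num_tokens - 1 :]              # [12, 13, 14]
    # Decoding window_ids and cutting off the decoded prefix_ids text leaves
    # exactly the text of the generated tokens [13, 14].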
@@ -728,8 +728,8 @@ class IdeficsCausalLM(Model):
             if i % self.world_size == self.rank:
                 if stop:
                     # Decode generated tokens
-                    output_text = self.decode(
-                        all_input_ids[-stopping_criteria.current_tokens :, 0]
+                    output_text = self.decode_generated_tokens(
+                        all_input_ids[:, 0], stopping_criteria.current_tokens
                     )
                     # Get seed
                     if isinstance(next_token_chooser.choice, Sampling):
@@ -85,6 +85,21 @@ class Model(ABC):
             return new_text, read_offset, len(all_input_ids)
         else:
             return "", prefix_offset, read_offset

+    def decode_generated_tokens(
+        self,
+        all_input_ids: List[int],
+        num_tokens: int = 0,
+    ) -> str:
+        # Like in `decode_token()`, the prefix text is necessary only to defeat cleanup algorithms in the decode.
+        prefix_text = self.tokenizer.decode(
+            all_input_ids[-num_tokens-1:-num_tokens], skip_special_tokens=False
+        )
+        new_text = self.tokenizer.decode(
+            all_input_ids[-num_tokens-1:], skip_special_tokens=False
+        )
+        new_text = new_text[len(prefix_text):]
+        return new_text
+
     def check_initialized(self):
         uninitialized_parameters = []
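A minimal standalone sketch of why the one-token prefix matters, assuming the transformers library and the public t5-small tokenizer (any SentencePiece model shows the effect); the slices mirror the new method, but the script itself is illustrative, not part of the commit:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("t5-small")

    ids = tokenizer.encode("how are you", add_special_tokens=False)
    num_tokens = 2  # pretend the last two tokens were generated

    # Naive suffix decode (old behaviour): SentencePiece drops the
    # leading-space marker of the first decoded token.
    naive = tokenizer.decode(ids[-num_tokens:], skip_special_tokens=False)

    # Prefix trick (new behaviour): decode one extra token, then cut its text
    # off, so the generated tokens are decoded in context.
    prefix = tokenizer.decode(ids[-num_tokens - 1 : -num_tokens], skip_special_tokens=False)
    window = tokenizer.decode(ids[-num_tokens - 1 :], skip_special_tokens=False)
    robust = window[len(prefix):]

    print(repr(naive))   # e.g. 'are you'  (leading space lost)
    print(repr(robust))  # e.g. ' are you' (matches the full decode)

As the inline comment in the diff notes, this is the same prefix trick decode_token() already uses for per-token streaming, applied here once over the whole generated span.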