fix indent issue

fix tokenizer issue

fix tokenizer issue

fix tokenizer issue

fix tokenizer issue

fix tokenizer issue

add 'decode_generated_tokens()' function
This commit is contained in:
bangoz 2023-09-13 16:33:12 +00:00
parent c8a01d7591
commit 00359fcdc5
4 changed files with 23 additions and 6 deletions

View File

@ -641,8 +641,8 @@ class CausalLM(Model):
if i % self.world_size == self.rank:
if stop:
# Decode generated tokens
output_text = self.decode(
all_input_ids[-stopping_criteria.current_tokens :, 0]
output_text = self.decode_generated_tokens(
all_input_ids[:, 0], stopping_criteria.current_tokens
)
# Get seed
if isinstance(next_token_chooser.choice, Sampling):

View File

@ -1008,9 +1008,11 @@ class FlashCausalLM(Model):
if i % self.world_size == self.rank:
if stop:
# Decode generated tokens
output_text = self.decode(
all_input_ids[-stopping_criteria.current_tokens :]
output_text = self.decode_generated_tokens(
all_input_ids,
stopping_criteria.current_tokens,
)
generated_text = GeneratedText(
output_text,
stopping_criteria.current_tokens,

View File

@ -728,8 +728,8 @@ class IdeficsCausalLM(Model):
if i % self.world_size == self.rank:
if stop:
# Decode generated tokens
output_text = self.decode(
all_input_ids[-stopping_criteria.current_tokens :, 0]
output_text = self.decode_generated_tokens(
all_input_ids[:, 0], stopping_criteria.current_tokens
)
# Get seed
if isinstance(next_token_chooser.choice, Sampling):

View File

@ -85,6 +85,21 @@ class Model(ABC):
return new_text, read_offset, len(all_input_ids)
else:
return "", prefix_offset, read_offset
def decode_generated_tokens(
self,
all_input_ids: List[int],
num_tokens: int = 0,
) -> str:
# Like in `decode_token()`, the prefix text is necessary only to defeat cleanup algorithms in the decode.
prefix_text = self.tokenizer.decode(
all_input_ids[-num_tokens-1:-num_tokens], skip_special_tokens=False
)
new_text = self.tokenizer.decode(
all_input_ids[-num_tokens-1:], skip_special_tokens=False
)
new_text = new_text[len(prefix_text):]
return new_text
def check_initialized(self):
uninitialized_parameters = []