Share computation for top-n-token decoding

Vincent Brouwers 2023-07-25 14:55:32 +00:00
parent f809f179dc
commit 494e6b1c61
3 changed files with 51 additions and 14 deletions


@@ -949,6 +949,7 @@ class FlashCausalLM(Model):
             batch.all_input_ids,
             batch.next_token_chooser.do_sample,
             batch.next_token_chooser.seeds,
+            batch.top_n_tokens,
             next_token_ids,
             next_token_logprobs,
             batch_top_token_ids,
@@ -965,26 +966,31 @@ class FlashCausalLM(Model):
             all_input_ids,
             do_sample,
             seed,
+            top_n_tokens,
             next_token_id,
             next_token_logprob,
             top_token_ids,
             top_token_logprobs,
         ) in enumerate(iterator):
             top_tokens = []
-            for token_id, token_logprob in zip(top_token_ids, top_token_logprobs):
-                tok_itm = token_id
-                top_tokens.append(
-                    TopToken(
-                        token_id=token_id,
-                        token_logprob=token_logprob,
-                        token_text=self.decode_token(
-                            all_input_ids=all_input_ids + [tok_itm],
-                            prefix_offset=prefix_offset,
-                            read_offset=read_offset,
-                        )[0],
-                        token_is_special=tok_itm in self.all_special_ids,
-                    )
-                )
+            if top_n_tokens > 0:
+                top_token_texts = self.decode_tokens(
+                    input_ids=all_input_ids,
+                    new_input_ids=top_token_ids,
+                    prefix_offset=prefix_offset,
+                    read_offset=read_offset,
+                )
+                for token_id, (top_token_text, _, _), token_logprob in zip(
+                    top_token_ids, top_token_texts, top_token_logprobs
+                ):
+                    tok_itm = token_id
+                    top_tokens.append(
+                        TopToken(
+                            token_id=token_id,
+                            token_logprob=token_logprob,
+                            token_text=top_token_text,
+                            token_is_special=tok_itm in self.all_special_ids,
+                        )
+                    )

             # Append next token to all tokens
             all_input_ids.append(next_token_id)
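
Editor's note (not part of the diff): the call-site change above swaps one `decode_token` call per requested top token for a single `decode_tokens` call per request. A rough sketch of the difference, using an arbitrary Hugging Face tokenizer ("gpt2" is only an example):

```python
# Sketch of the shared-computation idea: with n top tokens requested, the old
# loop re-decoded the prompt prefix n times, while the new path decodes the
# prefix once and batch-decodes the n candidates.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # example tokenizer only

all_input_ids = tokenizer("The quick brown", add_special_tokens=False).input_ids
top_token_ids = [tokenizer.encode(" fox")[0], tokenizer.encode(" dog")[0]]

# Old shape of the work: one full decode per candidate id.
per_candidate = [
    tokenizer.decode(all_input_ids + [tid], skip_special_tokens=False)
    for tid in top_token_ids
]

# New shape of the work: one batched decode over all candidate sequences.
batched = tokenizer.batch_decode(
    [all_input_ids + [tid] for tid in top_token_ids], skip_special_tokens=False
)

assert per_candidate == batched
```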


@@ -86,6 +86,37 @@ class Model(ABC):
         else:
             return "", prefix_offset, read_offset

+    def decode_tokens(
+        self,
+        input_ids: List[int],
+        new_input_ids: List[int],
+        prefix_offset: int = 0,
+        read_offset: int = 0,
+    ) -> List[Tuple[str, int, int]]:
+        """Version of decode_token that supports multiple new tokens for the same prefix."""
+        # The prefix text is necessary only to defeat cleanup algorithms in the decode
+        # which decide to add a space or not depending on the surrounding ids.
+        prefix_text = self.tokenizer.decode(
+            input_ids[prefix_offset:read_offset], skip_special_tokens=False
+        )
+
+        new_sequences = [
+            input_ids[prefix_offset:] + [new_id] for new_id in new_input_ids
+        ]
+        new_texts = self.tokenizer.batch_decode(new_sequences, skip_special_tokens=False)
+
+        results = []
+        for new_text in new_texts:
+            if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
+                # utf-8 char at the end means it's a potential unfinished byte sequence
+                # from byte fallback tokenization.
+                # If it's in the middle, it's probably a real invalid id generated
+                # by the model
+                new_text = new_text[len(prefix_text) :]
+                results.append((new_text, read_offset, len(input_ids) + 1))
+            else:
+                results.append(("", prefix_offset, read_offset))
+        return results
+
     def check_initialized(self):
         uninitialized_parameters = []
         for n, p in self.model.named_parameters():
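
Editor's note (not part of the diff): a standalone, illustrative copy of the helper above, showing the return contract: one `(suffix_text, new_prefix_offset, new_read_offset)` tuple per candidate id, all computed against the same decoded prefix. The harness names and the "gpt2" tokenizer are assumptions for demonstration only.

```python
from typing import List, Tuple

from transformers import AutoTokenizer


def decode_tokens_standalone(
    tokenizer,
    input_ids: List[int],
    new_input_ids: List[int],
    prefix_offset: int = 0,
    read_offset: int = 0,
) -> List[Tuple[str, int, int]]:
    # Decode the shared prefix once; it is only used to strip tokenizer cleanup
    # artifacts (leading spaces etc.) from each candidate decode.
    prefix_text = tokenizer.decode(
        input_ids[prefix_offset:read_offset], skip_special_tokens=False
    )
    new_sequences = [input_ids[prefix_offset:] + [new_id] for new_id in new_input_ids]
    new_texts = tokenizer.batch_decode(new_sequences, skip_special_tokens=False)
    results = []
    for new_text in new_texts:
        if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
            results.append((new_text[len(prefix_text):], read_offset, len(input_ids) + 1))
        else:
            # Possible unfinished byte sequence: report no text, keep old offsets.
            results.append(("", prefix_offset, read_offset))
    return results


tokenizer = AutoTokenizer.from_pretrained("gpt2")  # example tokenizer only
ids = tokenizer("Hello there", add_special_tokens=False).input_ids
candidates = [tokenizer.encode(",")[0], tokenizer.encode("!")[0]]
print(decode_tokens_standalone(tokenizer, ids, candidates, read_offset=len(ids)))
# -> one (suffix_text, new_prefix_offset, new_read_offset) tuple per candidate id
```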


@@ -345,7 +345,7 @@ def batch_top_tokens(top_n_tokens: torch.Tensor, logprobs: torch.Tensor):
     """Find the top n most likely tokens for a batch of generations."""
     top_n_tokens = torch.tensor(top_n_tokens)
     if top_n_tokens.min() == 0:
-        return [], []
+        return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens)

    # Ensure top_n doesn't exceed vocab size
    top_n_tokens = torch.clip(top_n_tokens, max=logprobs.size(-1))
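
Editor's note (not part of the diff): the early return presumably changed because the per-request loop in `FlashCausalLM` zips `batch_top_token_ids` and `batch_top_token_logprobs` against the rest of the batch, so even when no request asks for top tokens the lists must keep one (empty) entry per request. A small pure-Python illustration of that alignment:

```python
# Hypothetical batch of three requests, none of which asked for top tokens.
requests = ["req-a", "req-b", "req-c"]

# Old behaviour: flat empty lists collapse the zip to zero iterations,
# so no request would be processed at all.
old_ids, old_logprobs = [], []
assert list(zip(requests, old_ids, old_logprobs)) == []

# New behaviour: one empty entry per request keeps the batch length, so the
# zip still yields once per request. Repeating the same empty list object via
# [[]] * n is fine here because the entries are only read, never mutated.
new_ids, new_logprobs = [[]] * len(requests), [[]] * len(requests)
assert [r for r, ids, lps in zip(requests, new_ids, new_logprobs)] == requests
```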