Share computation for top-n-token decoding

Commit 494e6b1c61 (parent f809f179dc) in
https://github.com/huggingface/text-generation-inference.git
@@ -949,6 +949,7 @@ class FlashCausalLM(Model):
             batch.all_input_ids,
             batch.next_token_chooser.do_sample,
             batch.next_token_chooser.seeds,
             batch.top_n_tokens,
             next_token_ids,
             next_token_logprobs,
             batch_top_token_ids,
@@ -965,26 +966,31 @@ class FlashCausalLM(Model):
             all_input_ids,
             do_sample,
             seed,
             top_n_tokens,
             next_token_id,
             next_token_logprob,
             top_token_ids,
             top_token_logprobs,
         ) in enumerate(iterator):
             top_tokens = []
-            for token_id, token_logprob in zip(top_token_ids, top_token_logprobs):
-                tok_itm = token_id
-                top_tokens.append(
-                    TopToken(
-                        token_id=token_id,
-                        token_logprob=token_logprob,
-                        token_text=self.decode_token(
-                            all_input_ids=all_input_ids + [tok_itm],
-                            prefix_offset=prefix_offset,
-                            read_offset=read_offset,
-                        )[0],
-                        token_is_special=tok_itm in self.all_special_ids,
-                    )
-                )
+            if top_n_tokens > 0:
+                top_token_texts = self.decode_tokens(
+                    input_ids=all_input_ids,
+                    new_input_ids=top_token_ids,
+                    prefix_offset=prefix_offset,
+                    read_offset=read_offset,
+                )
+                for token_id, (top_token_text, _, _), token_logprob in zip(
+                    top_token_ids, top_token_texts, top_token_logprobs
+                ):
+                    tok_itm = token_id
+                    top_tokens.append(
+                        TopToken(
+                            token_id=token_id,
+                            token_logprob=token_logprob,
+                            token_text=top_token_text,
+                            token_is_special=tok_itm in self.all_special_ids,
+                        )
+                    )
 
             # Append next token to all tokens
             all_input_ids.append(next_token_id)
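The hunk above swaps the per-candidate `self.decode_token` call, which re-decoded the shared prefix once for every top-n candidate, for a single `self.decode_tokens` call that decodes the prefix once and batch-decodes all candidate continuations. A minimal sketch of the idea, not taken from the commit; the gpt2 tokenizer and the example strings are purely illustrative:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

prefix_ids = tokenizer("The quick brown", add_special_tokens=False)["input_ids"]
candidate_ids = tokenizer(" fox dog cat", add_special_tokens=False)["input_ids"]

# Old shape of the work: one full decode per candidate, each re-decoding the prefix.
old_texts = [tokenizer.decode(prefix_ids + [c]) for c in candidate_ids]

# New shape of the work: decode the prefix once, batch-decode the candidates,
# then strip the shared prefix from each result.
prefix_text = tokenizer.decode(prefix_ids)
new_texts = [
    text[len(prefix_text):]
    for text in tokenizer.batch_decode([prefix_ids + [c] for c in candidate_ids])
]

# Both paths agree on the decoded candidate texts.
assert old_texts == [prefix_text + text for text in new_texts]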
@@ -86,6 +86,37 @@ class Model(ABC):
         else:
             return "", prefix_offset, read_offset
 
+    def decode_tokens(
+        self,
+        input_ids: List[int],
+        new_input_ids: List[int],
+        prefix_offset: int = 0,
+        read_offset: int = 0,
+    ) -> Tuple[str, int, int]:
+        """Version of decode_token that supports multiple new tokens for the same prefix."""
+
+        # The prefix text is necessary only to defeat cleanup algorithms in the decode
+        # which decide to add a space or not depending on the surrounding ids.
+        prefix_text = self.tokenizer.decode(
+            input_ids[prefix_offset:read_offset], skip_special_tokens=False
+        )
+
+        new_sequences = [input_ids[prefix_offset:] + [new_id] for new_id in new_input_ids]
+        new_texts = self.tokenizer.batch_decode(new_sequences, skip_special_tokens=False)
+
+        results = []
+        for new_text in new_texts:
+            if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
+                # utf-8 char at the end means it's a potential unfinished byte sequence
+                # from byte fallback tokenization.
+                # If it's in the middle, it's probably a real invalid id generated
+                # by the model
+                new_text = new_text[len(prefix_text) :]
+                results.append((new_text, read_offset, len(input_ids) + 1))
+            else:
+                results.append(("", prefix_offset, read_offset))
+        return results
+
     def check_initialized(self):
         uninitialized_parameters = []
         for n, p in self.model.named_parameters():
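For the offset convention, here is a self-contained sketch of the contract the new helper follows, using an invented toy tokenizer instead of a real one. The class, its vocabulary, and the function name are made up for illustration; only the control flow mirrors the method added above. Each candidate id yields one (text, prefix_offset, read_offset) tuple, and a candidate ending in the Unicode replacement character, a possible unfinished byte-fallback sequence, comes back as empty text.

from typing import List, Tuple


class ToyTokenizer:
    """Maps each id to a fixed string; just enough to mimic decode/batch_decode."""

    vocab = {0: "Hello", 1: " world", 2: " foo", 3: " bar"}

    def decode(self, ids: List[int]) -> str:
        return "".join(self.vocab[i] for i in ids)

    def batch_decode(self, batch: List[List[int]]) -> List[str]:
        return [self.decode(ids) for ids in batch]


def decode_tokens_sketch(
    tokenizer: ToyTokenizer,
    input_ids: List[int],
    new_input_ids: List[int],
    prefix_offset: int,
    read_offset: int,
) -> List[Tuple[str, int, int]]:
    # Same flow as the method above: decode the shared prefix once,
    # then batch-decode every candidate continuation and strip the prefix.
    prefix_text = tokenizer.decode(input_ids[prefix_offset:read_offset])
    new_sequences = [input_ids[prefix_offset:] + [i] for i in new_input_ids]
    new_texts = tokenizer.batch_decode(new_sequences)
    results = []
    for new_text in new_texts:
        if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
            results.append((new_text[len(prefix_text):], read_offset, len(input_ids) + 1))
        else:
            results.append(("", prefix_offset, read_offset))
    return results


print(decode_tokens_sketch(ToyTokenizer(), [0, 1], [2, 3], prefix_offset=0, read_offset=2))
# [(' foo', 2, 3), (' bar', 2, 3)]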
@@ -345,7 +345,7 @@ def batch_top_tokens(top_n_tokens: torch.Tensor, logprobs: torch.Tensor):
     """Find the top n most likely tokens for a batch of generations."""
     top_n_tokens = torch.tensor(top_n_tokens)
     if top_n_tokens.min() == 0:
-        return [], []
+        return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens)
 
     # Ensure top_n doesn't exceed vocab size
     top_n_tokens = torch.clip(top_n_tokens, max=logprobs.size(-1))
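The early-return change matters because the flash model zips the returned per-request lists against the other per-request batch fields (see batch_top_token_ids in the first hunk): a pair of flat empty lists would make that zip yield nothing, while one empty list per request keeps every request aligned, which is presumably why the early exit now returns per-request empties. A tiny illustration with made-up request names, not TGI code:

requests = ["req-0", "req-1", "req-2"]

# Previous early exit: nothing to zip against, the generation loop sees no entries.
old_top_ids = []
assert list(zip(requests, old_top_ids)) == []

# New early exit: one empty list per request, so the zip stays aligned with the batch.
new_top_ids = [[]] * len(requests)
assert [ids for _, ids in zip(requests, new_top_ids)] == [[], [], []]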