Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 20:04:52 +00:00)
Share computation for top-n-token decoding
parent f809f179dc
commit 494e6b1c61
@@ -949,6 +949,7 @@ class FlashCausalLM(Model):
             batch.all_input_ids,
             batch.next_token_chooser.do_sample,
             batch.next_token_chooser.seeds,
+            batch.top_n_tokens,
             next_token_ids,
             next_token_logprobs,
             batch_top_token_ids,
@@ -965,26 +966,31 @@ class FlashCausalLM(Model):
             all_input_ids,
             do_sample,
             seed,
+            top_n_tokens,
             next_token_id,
             next_token_logprob,
             top_token_ids,
             top_token_logprobs,
         ) in enumerate(iterator):
             top_tokens = []
-            for token_id, token_logprob in zip(top_token_ids, top_token_logprobs):
-                tok_itm = token_id
-                top_tokens.append(
-                    TopToken(
-                        token_id=token_id,
-                        token_logprob=token_logprob,
-                        token_text=self.decode_token(
-                            all_input_ids=all_input_ids + [tok_itm],
-                            prefix_offset=prefix_offset,
-                            read_offset=read_offset,
-                        )[0],
-                        token_is_special=tok_itm in self.all_special_ids,
-                    )
-                )
+            if top_n_tokens > 0:
+                top_token_texts = self.decode_tokens(
+                    input_ids=all_input_ids,
+                    new_input_ids=top_token_ids,
+                    prefix_offset=prefix_offset,
+                    read_offset=read_offset,
+                )
+                for token_id, (top_token_text, _, _), token_logprob in zip(top_token_ids, top_token_texts, top_token_logprobs):
+                    tok_itm = token_id
+                    top_tokens.append(
+                        TopToken(
+                            token_id=token_id,
+                            token_logprob=token_logprob,
+                            token_text=top_token_text,
+                            token_is_special=tok_itm in self.all_special_ids,
+                        )
+                    )

             # Append next token to all tokens
             all_input_ids.append(next_token_id)
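What the FlashCausalLM change above shares: instead of one self.decode_token(...) call per top-n candidate, all candidates of a request now go through a single self.decode_tokens(...) call (added in the next hunk). A rough, self-contained sketch of the saving, using a stub tokenizer that only counts backend calls; nothing below is repository code, and it assumes decode_token, like decode_tokens, also decodes the prefix window on every call:

class CountingTokenizer:
    """Stub tokenizer that only counts decode calls."""

    def __init__(self):
        self.decode_calls = 0

    def decode(self, ids, skip_special_tokens=False):
        self.decode_calls += 1
        return "x" * len(ids)

    def batch_decode(self, sequences, skip_special_tokens=False):
        self.decode_calls += 1  # one backend call for the whole batch
        return ["x" * len(ids) for ids in sequences]


prefix = list(range(10))                 # ids the request has generated so far
candidates = [101, 102, 103, 104, 105]   # hypothetical top-5 candidate ids

old = CountingTokenizer()
for cand in candidates:
    # old path: one decode_token-style call per candidate, each decoding
    # both the prefix window and the prefix + candidate text
    old.decode(prefix)
    old.decode(prefix + [cand])

new = CountingTokenizer()
# new path: the prefix is decoded once, and every candidate continuation
# goes through a single batch_decode call
new.decode(prefix)
new.batch_decode([prefix + [c] for c in candidates])

print(old.decode_calls, new.decode_calls)  # 10 vs 2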
@@ -86,6 +86,37 @@ class Model(ABC):
         else:
             return "", prefix_offset, read_offset

+    def decode_tokens(
+        self,
+        input_ids: List[int],
+        new_input_ids: List[int],
+        prefix_offset: int = 0,
+        read_offset: int = 0,
+    ) -> Tuple[str, int, int]:
+        """Version of decode_token that supports multiple new tokens for the same prefix."""
+
+        # The prefix text is necessary only to defeat cleanup algorithms in the decode
+        # which decide to add a space or not depending on the surrounding ids.
+        prefix_text = self.tokenizer.decode(
+            input_ids[prefix_offset:read_offset], skip_special_tokens=False
+        )
+
+        new_sequences = [input_ids[prefix_offset:] + [new_id] for new_id in new_input_ids]
+        new_texts = self.tokenizer.batch_decode(new_sequences, skip_special_tokens=False)
+
+        results = []
+        for new_text in new_texts:
+            if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
+                # utf-8 char at the end means it's a potential unfinished byte sequence
+                # from byte fallback tokenization.
+                # If it's in the middle, it's probably a real invalid id generated
+                # by the model
+                new_text = new_text[len(prefix_text) :]
+                results.append((new_text, read_offset, len(input_ids) + 1))
+            else:
+                results.append(("", prefix_offset, read_offset))
+        return results
+
     def check_initialized(self):
         uninitialized_parameters = []
         for n, p in self.model.named_parameters():
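The new Model.decode_tokens above decodes the shared prefix once and batch-decodes every candidate continuation; despite the Tuple[str, int, int] annotation, the body builds and returns a list of such (text, offset, offset) triples, one per candidate. A minimal standalone sketch of the same idea against a plain Hugging Face tokenizer; the function name decode_candidates and the gpt2 checkpoint are illustrative choices, not repository code:

from typing import List, Tuple

from transformers import AutoTokenizer


def decode_candidates(
    tokenizer,
    input_ids: List[int],
    new_input_ids: List[int],
    prefix_offset: int = 0,
    read_offset: int = 0,
) -> List[Tuple[str, int, int]]:
    # Decode the shared prefix once; it is only needed to detect whether the
    # tokenizer would add or drop a space around the appended candidate.
    prefix_text = tokenizer.decode(
        input_ids[prefix_offset:read_offset], skip_special_tokens=False
    )
    # One batch_decode call covers every candidate continuation of the prefix.
    new_sequences = [input_ids[prefix_offset:] + [new_id] for new_id in new_input_ids]
    new_texts = tokenizer.batch_decode(new_sequences, skip_special_tokens=False)

    results = []
    for new_text in new_texts:
        if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
            results.append((new_text[len(prefix_text):], read_offset, len(input_ids) + 1))
        else:
            # Unfinished byte sequence (or nothing new): keep the old offsets.
            results.append(("", prefix_offset, read_offset))
    return results


if __name__ == "__main__":
    tok = AutoTokenizer.from_pretrained("gpt2")
    ids = tok(" The quick brown").input_ids
    candidates = tok(" fox dog cat").input_ids  # three single-token candidates for gpt2
    print(decode_candidates(tok, ids, candidates, prefix_offset=0, read_offset=len(ids)))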
@@ -345,7 +345,7 @@ def batch_top_tokens(top_n_tokens: torch.Tensor, logprobs: torch.Tensor):
     """Find the top n most likely tokens for a batch of generations."""
     top_n_tokens = torch.tensor(top_n_tokens)
     if top_n_tokens.min() == 0:
-        return [], []
+        return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens)

     # Ensure top_n doesn't exceed vocab size
     top_n_tokens = torch.clip(top_n_tokens, max=logprobs.size(-1))
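One way to read the changed early return in batch_top_tokens: the resulting batch_top_token_ids / batch_top_token_logprobs lists are zipped per request in the FlashCausalLM iterator above, so a bare [] would end that zip immediately instead of yielding one empty entry per request. A toy illustration, not repository code:

requests = ["req-0", "req-1", "req-2"]  # hypothetical batch of three requests

# Before: empty lists end the zip immediately, so no request gets an entry.
old_ids, old_logprobs = [], []
print(list(zip(requests, old_ids, old_logprobs)))    # -> []

# After: one empty placeholder per request keeps the zip aligned with the batch.
new_ids = [[]] * len(requests)
new_logprobs = [[]] * len(requests)
print(list(zip(requests, new_ids, new_logprobs)))    # -> three aligned entries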