Mirror of https://github.com/huggingface/text-generation-inference.git

Add batched top-n-tokens to FlashCausalLM

parent 0facd94738
commit dbb92c20e7
@@ -1,5 +1,6 @@
 import math
 import itertools
+from text_generation_server.utils.tokens import get_top_tokens, batch_top_tokens
 import torch
 import torch.distributed
 
@@ -16,6 +17,7 @@ from text_generation_server.models.types import (
     PrefillTokens,
     Generation,
     GeneratedText,
+    TopToken,
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
@@ -165,6 +167,7 @@ class FlashCausalLMBatch(Batch):
     # Generation helpers
     next_token_chooser: HeterogeneousNextTokenChooser
     stopping_criterias: List[StoppingCriteria]
+    top_n_tokens: List[int]
 
     # Number of blocks in this batch
     blocks: int
@@ -217,6 +220,7 @@ class FlashCausalLMBatch(Batch):
 
         next_token_chooser_parameters = []
         stopping_criterias = []
+        top_n_tokens = []
 
         # Cumulative length
         cumulative_length = 0
@@ -259,6 +263,7 @@ class FlashCausalLMBatch(Batch):
             )
             max_new_tokens = stopping_criteria.max_new_tokens
             stopping_criterias.append(stopping_criteria)
+            top_n_tokens.append(r.top_n_tokens)
 
             # Paged attention
             # Remove one as the first token does not have a past
@@ -378,6 +383,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
             blocks=blocks,
             max_blocks=max_blocks,
         )
@@ -417,6 +423,7 @@ class FlashCausalLMBatch(Batch):
         read_offsets = []
 
         stopping_criterias = []
+        top_n_tokens = []
 
         blocks = 0
         max_blocks = 0
@@ -443,6 +450,8 @@ class FlashCausalLMBatch(Batch):
             stopping_criteria = self.stopping_criterias[idx]
             stopping_criterias.append(stopping_criteria)
 
+            top_n_tokens.append(self.top_n_tokens[idx])
+
             remaining_tokens = (
                 stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
             )
@@ -518,6 +527,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
             blocks=blocks,
             max_blocks=max_blocks,
         )
@@ -577,6 +587,7 @@ class FlashCausalLMBatch(Batch):
 
         next_token_chooser_parameters = []
         stopping_criterias = []
+        top_n_tokens = []
 
         # Cumulative length
         cumulative_batch_size = 0
@@ -624,6 +635,8 @@ class FlashCausalLMBatch(Batch):
             next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
             stopping_criterias.extend(batch.stopping_criterias)
 
+            top_n_tokens.extend(batch.top_n_tokens)
+
             # Update
             cumulative_batch_size += len(batch)
             cumulative_slots += len(batch.slots)
@@ -666,6 +679,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
             blocks=blocks,
             max_blocks=max_blocks,
         )
@@ -831,10 +845,14 @@ class FlashCausalLM(Model):
         else:
             next_token_logits = out
 
-        next_input_ids, next_token_logprobs = batch.next_token_chooser(
+        next_input_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
             batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits
         )
 
+        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
+            batch.top_n_tokens, logprobs
+        )
+
         if prefill:
             if len(batch) > 1 and prefill_logprobs:
                 # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
@@ -933,6 +951,8 @@ class FlashCausalLM(Model):
             batch.next_token_chooser.seeds,
             next_token_ids,
             next_token_logprobs,
+            batch_top_token_ids,
+            batch_top_token_logprobs,
         )
 
         # For each member of the batch
@@ -947,7 +967,25 @@ class FlashCausalLM(Model):
             seed,
             next_token_id,
             next_token_logprob,
+            top_token_ids,
+            top_token_logprobs,
         ) in enumerate(iterator):
+            top_tokens = []
+            for token_id, token_logprob in zip(top_token_ids, top_token_logprobs):
+                tok_itm = token_id
+                top_tokens.append(
+                    TopToken(
+                        token_id=token_id,
+                        token_logprob=token_logprob,
+                        token_text=self.decode_token(
+                            all_input_ids=all_input_ids + [tok_itm],
+                            prefix_offset=prefix_offset,
+                            read_offset=read_offset,
+                        )[0],
+                        token_is_special=tok_itm in self.all_special_ids,
+                    )
+                )
+
             # Append next token to all tokens
             all_input_ids.append(next_token_id)
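Note: the flash_causal_lm.py hunks end here; the hunks below appear to come from text_generation_server/utils/tokens.py, the module the new import at the top of this diff points at. As a rough, hypothetical sketch of how the pieces fit together (not the actual TGI code, and with a greedy stand-in for HeterogeneousNextTokenChooser): the chooser hands back the full log-softmax matrix, one batched topk sized by the largest request is taken over it, and the result is sliced per request before being decoded into TopToken entries.

from typing import List
import torch

def pick_top_tokens(logits: torch.Tensor, top_n: List[int]):
    # Full log-softmax matrix, shape [batch, vocab], reused for both purposes.
    logprobs = torch.log_softmax(logits, dim=-1)
    next_ids = logprobs.argmax(dim=-1)  # greedy stand-in for the real chooser
    next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)

    # One batched topk sized by the largest request, then per-request slicing,
    # mirroring the batch_top_tokens helper defined further down in this diff.
    top_k = torch.topk(logprobs, k=max(top_n), dim=1, sorted=True)
    top_ids = [ids[:n] for ids, n in zip(top_k.indices.tolist(), top_n)]
    top_vals = [vals[:n] for vals, n in zip(top_k.values.tolist(), top_n)]
    return next_ids, next_logprobs, top_ids, top_vals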
@@ -230,11 +230,10 @@ class HeterogeneousNextTokenChooser:
             scores = warper(input_ids, scores)
 
         next_ids = self.choice(scores)
-        next_logprobs = torch.gather(
-            torch.log_softmax(scores, -1), 1, next_ids.view(-1, 1)
-        ).view(-1)
+        logprobs = torch.log_softmax(scores, -1)
+        next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
 
-        return next_ids, next_logprobs
+        return next_ids, next_logprobs, logprobs
 
     def filter(self, indices):
         if self.watermark_processor is not None:
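Note: the change above computes torch.log_softmax(scores, -1) once, keeps the full matrix, and returns it as a third value so that callers (such as FlashCausalLM.generate_token above) can derive top-n tokens without a second softmax pass. A tiny illustration with made-up scores for 2 requests over a 4-token vocabulary:

import torch

scores = torch.tensor([[2.0, 1.0, 0.5, 0.1],
                       [0.1, 0.2, 3.0, 0.3]])
next_ids = scores.argmax(dim=-1)          # stand-in for self.choice(scores)
logprobs = torch.log_softmax(scores, -1)  # full [2, 4] matrix, now returned as well
next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)  # shape [2]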
@@ -342,6 +341,28 @@ class HeterogeneousSampling:
         return self
 
 
+def batch_top_tokens(top_n_tokens: torch.Tensor, logprobs: torch.Tensor):
+    """Find the top n most likely tokens for a batch of generations."""
+    top_n_tokens = torch.tensor(top_n_tokens)
+    if top_n_tokens.min() == 0:
+        return [], []
+
+    # Ensure top_n doesn't exceed vocab size
+    top_n_tokens = torch.clip(top_n_tokens, max=logprobs.size(-1))
+
+    # Take the topk using the highest requested top_n_tokens.
+    top_k = torch.topk(logprobs, k=max(top_n_tokens), dim=1, sorted=True)
+
+    # Move all digits into a list at once to prevent multiple GPU syncs
+    top_indices = top_k.indices.tolist()
+    top_values = top_k.values.tolist()
+
+    return (
+        [idxs[:n] for idxs, n in zip(top_indices, top_n_tokens)],
+        [vals[:n] for vals, n in zip(top_values, top_n_tokens)],
+    )
+
+
 def get_top_tokens(
     requested_n: int,
     logprobs,
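Note: a hypothetical call to the batch_top_tokens helper added above, with three requests asking for 2, 1 and 3 top tokens over a 5-token vocabulary:

import torch

logprobs = torch.log_softmax(torch.randn(3, 5), dim=-1)
top_ids, top_logprobs = batch_top_tokens([2, 1, 3], logprobs)
# len(top_ids[0]) == 2, len(top_ids[1]) == 1, len(top_ids[2]) == 3,
# and top_logprobs mirrors that structure with the matching log-probabilities.

As written, the helper returns ([], []) for the whole batch as soon as any single request asks for 0 top tokens, since it only checks top_n_tokens.min() == 0.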
@@ -354,7 +375,8 @@ def get_top_tokens(
     if not requested_n:
         return []
 
-    flat_scores = logprobs[-1]
+    # Dirty hack
+    flat_scores = logprobs if len(logprobs.shape) == 1 else logprobs[-1]
     # Ensure top_n doesn't exceed vocab size
     top_n = min(requested_n, flat_scores.size(-1))
     # Get nth highest value, ensure it's not -inf (for example if top_n > top_k)
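Note: the "Dirty hack" line above makes get_top_tokens accept either a 1-D row of scores or a 2-D [seq_len, vocab] matrix, in which case only the last position is used; presumably this lets callers that already hold a single row of logprobs reuse the function. Illustrative shapes only, not taken from the repository:

import torch

row = torch.log_softmax(torch.randn(8), dim=-1)     # 1-D: scores for one position
seq = torch.log_softmax(torch.randn(4, 8), dim=-1)  # 2-D: scores for a whole sequence

flat_scores = row if len(row.shape) == 1 else row[-1]
assert flat_scores.shape == (8,)
flat_scores = seq if len(seq.shape) == 1 else seq[-1]
assert flat_scores.shape == (8,)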
@@ -368,14 +390,16 @@ def get_top_tokens(
     for tid_tensor in top_n_indices:
         tid_item = tid_tensor[0].item()
         token_text, _, _ = decode_fn(
-            torch.cat([decoder_input_ids, tid_tensor]),
+            torch.cat([decoder_input_ids, tid_tensor])
+            if isinstance(decoder_input_ids, torch.Tensor)
+            else decoder_input_ids + [tid_item],
             prefix_offset,
             read_offset,
         )
         top_tokens.append(
             TopToken(
                 token_id=tid_item,
-                token_logprob=logprobs[-1, tid_tensor],
+                token_logprob=flat_scores[tid_tensor],
                 token_text=token_text,
                 token_is_special=tid_item in special_tokens,
             )