mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00

Add batched top-n-tokens to FlashCausalLM

parent 0facd94738
commit dbb92c20e7
server/text_generation_server/models/flash_causal_lm.py

@@ -1,5 +1,6 @@
 import math
 import itertools
+from text_generation_server.utils.tokens import get_top_tokens, batch_top_tokens
 import torch
 import torch.distributed
 
@@ -16,6 +17,7 @@ from text_generation_server.models.types import (
     PrefillTokens,
     Generation,
     GeneratedText,
+    TopToken,
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
@@ -165,6 +167,7 @@ class FlashCausalLMBatch(Batch):
     # Generation helpers
     next_token_chooser: HeterogeneousNextTokenChooser
     stopping_criterias: List[StoppingCriteria]
+    top_n_tokens: List[int]
 
     # Number of blocks in this batch
     blocks: int
@@ -217,6 +220,7 @@ class FlashCausalLMBatch(Batch):
 
         next_token_chooser_parameters = []
         stopping_criterias = []
+        top_n_tokens = []
 
         # Cumulative length
         cumulative_length = 0
@@ -259,6 +263,7 @@ class FlashCausalLMBatch(Batch):
             )
             max_new_tokens = stopping_criteria.max_new_tokens
             stopping_criterias.append(stopping_criteria)
+            top_n_tokens.append(r.top_n_tokens)
 
             # Paged attention
             # Remove one as the first token des not have a past
@@ -378,6 +383,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
             blocks=blocks,
             max_blocks=max_blocks,
         )
@@ -417,6 +423,7 @@ class FlashCausalLMBatch(Batch):
         read_offsets = []
 
         stopping_criterias = []
+        top_n_tokens = []
 
         blocks = 0
         max_blocks = 0
@@ -443,6 +450,8 @@ class FlashCausalLMBatch(Batch):
             stopping_criteria = self.stopping_criterias[idx]
             stopping_criterias.append(stopping_criteria)
 
+            top_n_tokens.append(self.top_n_tokens[idx])
+
            remaining_tokens = (
                 stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
             )
@@ -518,6 +527,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
             blocks=blocks,
             max_blocks=max_blocks,
         )
@@ -577,6 +587,7 @@ class FlashCausalLMBatch(Batch):
 
         next_token_chooser_parameters = []
         stopping_criterias = []
+        top_n_tokens = []
 
         # Cumulative length
         cumulative_batch_size = 0
@@ -624,6 +635,8 @@ class FlashCausalLMBatch(Batch):
             next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
             stopping_criterias.extend(batch.stopping_criterias)
 
+            top_n_tokens.extend(top_n_tokens)
+
             # Update
             cumulative_batch_size += len(batch)
             cumulative_slots += len(batch.slots)
@@ -666,6 +679,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
             blocks=blocks,
             max_blocks=max_blocks,
         )
@@ -831,10 +845,14 @@ class FlashCausalLM(Model):
         else:
             next_token_logits = out
 
-        next_input_ids, next_token_logprobs = batch.next_token_chooser(
+        next_input_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
             batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits
         )
 
+        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
+            batch.top_n_tokens, logprobs
+        )
+
         if prefill:
             if len(batch) > 1 and prefill_logprobs:
                 # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
@@ -933,6 +951,8 @@ class FlashCausalLM(Model):
             batch.next_token_chooser.seeds,
             next_token_ids,
             next_token_logprobs,
+            batch_top_token_ids,
+            batch_top_token_logprobs,
         )
 
         # For each member of the batch
@@ -947,7 +967,25 @@ class FlashCausalLM(Model):
             seed,
             next_token_id,
             next_token_logprob,
+            top_token_ids,
+            top_token_logprobs,
         ) in enumerate(iterator):
+            top_tokens = []
+            for token_id, token_logprob in zip(top_token_ids, top_token_logprobs):
+                tok_itm = token_id
+                top_tokens.append(
+                    TopToken(
+                        token_id=token_id,
+                        token_logprob=token_logprob,
+                        token_text=self.decode_token(
+                            all_input_ids=all_input_ids + [tok_itm],
+                            prefix_offset=prefix_offset,
+                            read_offset=read_offset,
+                        )[0],
+                        token_is_special=tok_itm in self.all_special_ids,
+                    )
+                )
+
             # Append next token to all tokens
             all_input_ids.append(next_token_id)
 
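For orientation, a minimal sketch of the per-request pattern the decode loop above follows, in plain Python. Every name here (TopTokenInfo, fake_decode, SPECIAL_IDS, the example lists) is illustrative and not part of the commit: the batched id/logprob lists are zipped into the existing per-request iterator and wrapped into small token records.

from typing import List, NamedTuple

class TopTokenInfo(NamedTuple):
    # Illustrative stand-in for the TopToken type used in the diff.
    token_id: int
    token_logprob: float
    token_text: str
    token_is_special: bool

def fake_decode(token_id: int) -> str:
    # Stand-in for self.decode_token(...); the real code re-decodes with context.
    return f"<tok:{token_id}>"

SPECIAL_IDS = {0, 1, 2}

# Hypothetical per-request outputs of a batch_top_tokens-style call.
batch_top_token_ids: List[List[int]] = [[5, 9], [3, 1, 7]]
batch_top_token_logprobs: List[List[float]] = [[-0.2, -1.7], [-0.5, -0.9, -2.3]]

for top_token_ids, top_token_logprobs in zip(
    batch_top_token_ids, batch_top_token_logprobs
):
    top_tokens = [
        TopTokenInfo(
            token_id=tid,
            token_logprob=lp,
            token_text=fake_decode(tid),
            token_is_special=tid in SPECIAL_IDS,
        )
        for tid, lp in zip(top_token_ids, top_token_logprobs)
    ]
    print(top_tokens)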
server/text_generation_server/utils/tokens.py

@@ -230,11 +230,10 @@ class HeterogeneousNextTokenChooser:
             scores = warper(input_ids, scores)
 
         next_ids = self.choice(scores)
-        next_logprobs = torch.gather(
-            torch.log_softmax(scores, -1), 1, next_ids.view(-1, 1)
-        ).view(-1)
+        logprobs = torch.log_softmax(scores, -1)
+        next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
 
-        return next_ids, next_logprobs
+        return next_ids, next_logprobs, logprobs
 
     def filter(self, indices):
         if self.watermark_processor is not None:
@@ -342,6 +341,28 @@ class HeterogeneousSampling:
         return self
 
 
+def batch_top_tokens(top_n_tokens: torch.Tensor, logprobs: torch.Tensor):
+    """Find the top n most likely tokens for a batch of generations."""
+    top_n_tokens = torch.tensor(top_n_tokens)
+    if top_n_tokens.min() == 0:
+        return [], []
+
+    # Ensure top_n doesn't exceed vocab size
+    top_n_tokens = torch.clip(top_n_tokens, max=logprobs.size(-1))
+
+    # Take the topk using the highest requested top_n_tokens.
+    top_k = torch.topk(logprobs, k=max(top_n_tokens), dim=1, sorted=True)
+
+    # Move all digits into a list at once to prevent multiple GPU syncs
+    top_indices = top_k.indices.tolist()
+    top_values = top_k.values.tolist()
+
+    return (
+        [idxs[:n] for idxs, n in zip(top_indices, top_n_tokens)],
+        [vals[:n] for vals, n in zip(top_values, top_n_tokens)],
+    )
+
+
 def get_top_tokens(
     requested_n: int,
     logprobs,
@@ -354,7 +375,8 @@ def get_top_tokens(
     if not requested_n:
         return []
 
-    flat_scores = logprobs[-1]
+    # Dirty hack
+    flat_scores = logprobs if len(logprobs.shape) == 1 else logprobs[-1]
     # Ensure top_n doesn't exceed vocab size
     top_n = min(requested_n, flat_scores.size(-1))
     # Get nth highest value, ensure it's not -inf (for example if top_n > top_k)
@@ -368,14 +390,16 @@ def get_top_tokens(
     for tid_tensor in top_n_indices:
         tid_item = tid_tensor[0].item()
         token_text, _, _ = decode_fn(
-            torch.cat([decoder_input_ids, tid_tensor]),
+            torch.cat([decoder_input_ids, tid_tensor])
+            if isinstance(decoder_input_ids, torch.Tensor)
+            else decoder_input_ids + [tid_item],
             prefix_offset,
             read_offset,
         )
         top_tokens.append(
             TopToken(
                 token_id=tid_item,
-                token_logprob=logprobs[-1, tid_tensor],
+                token_logprob=flat_scores[tid_tensor],
                 token_text=token_text,
                 token_is_special=tid_item in special_tokens,
             )
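A minimal standalone sketch of the batching idea behind the new batch_top_tokens helper (the names requested_n and dummy_logits are illustrative, not part of the commit): compute log-softmax once, run a single topk at the largest requested n, move the result to the host with one .tolist() per tensor, then slice per request in plain Python.

import torch

# Illustrative inputs: 3 requests over a 10-token vocabulary,
# each asking for a different number of top tokens.
requested_n = [2, 5, 3]                  # per-request top_n_tokens
dummy_logits = torch.randn(3, 10)        # [batch, vocab] decode-step logits

logprobs = torch.log_softmax(dummy_logits, dim=-1)

# One topk at the largest requested n covers every request in the batch.
max_n = min(max(requested_n), logprobs.size(-1))
top_k = torch.topk(logprobs, k=max_n, dim=1, sorted=True)

# A single .tolist() per tensor moves everything to the host at once,
# avoiding one GPU sync per request.
top_indices = top_k.indices.tolist()
top_values = top_k.values.tolist()

# Each request keeps only the first n entries of its row.
batch_top_ids = [row[:n] for row, n in zip(top_indices, requested_n)]
batch_top_logprobs = [row[:n] for row, n in zip(top_values, requested_n)]

for ids, lps in zip(batch_top_ids, batch_top_logprobs):
    print(ids, [round(lp, 3) for lp in lps])

The committed helper additionally short-circuits with empty lists when the smallest requested top_n is zero, so batches that never ask for alternative tokens skip the topk entirely.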