Add batched top-n-tokens to FlashCausalLM

Vincent Brouwers 2023-07-25 14:17:25 +00:00 committed by Nicolas Patry
parent 0facd94738
commit dbb92c20e7
2 changed files with 70 additions and 8 deletions

@@ -1,5 +1,6 @@
 import math
 import itertools
+from text_generation_server.utils.tokens import get_top_tokens, batch_top_tokens

 import torch
 import torch.distributed
@@ -16,6 +17,7 @@ from text_generation_server.models.types import (
     PrefillTokens,
     Generation,
     GeneratedText,
+    TopToken,
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
@@ -165,6 +167,7 @@ class FlashCausalLMBatch(Batch):
     # Generation helpers
     next_token_chooser: HeterogeneousNextTokenChooser
     stopping_criterias: List[StoppingCriteria]
+    top_n_tokens: List[int]

     # Number of blocks in this batch
     blocks: int
@@ -217,6 +220,7 @@ class FlashCausalLMBatch(Batch):
         next_token_chooser_parameters = []
         stopping_criterias = []
+        top_n_tokens = []

         # Cumulative length
         cumulative_length = 0
@@ -259,6 +263,7 @@
             )
             max_new_tokens = stopping_criteria.max_new_tokens
             stopping_criterias.append(stopping_criteria)
+            top_n_tokens.append(r.top_n_tokens)

             # Paged attention
             # Remove one as the first token does not have a past
@@ -378,6 +383,7 @@
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
             blocks=blocks,
             max_blocks=max_blocks,
         )
@@ -417,6 +423,7 @@
         read_offsets = []

         stopping_criterias = []
+        top_n_tokens = []

         blocks = 0
         max_blocks = 0
@@ -443,6 +450,8 @@
             stopping_criteria = self.stopping_criterias[idx]
             stopping_criterias.append(stopping_criteria)
+            top_n_tokens.append(self.top_n_tokens[idx])

             remaining_tokens = (
                 stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
             )
@@ -518,6 +527,7 @@
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
             blocks=blocks,
             max_blocks=max_blocks,
         )
@@ -577,6 +587,7 @@
         next_token_chooser_parameters = []
         stopping_criterias = []
+        top_n_tokens = []

         # Cumulative length
         cumulative_batch_size = 0
@@ -624,6 +635,8 @@
             next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
             stopping_criterias.extend(batch.stopping_criterias)
+            top_n_tokens.extend(batch.top_n_tokens)

             # Update
             cumulative_batch_size += len(batch)
             cumulative_slots += len(batch.slots)
@@ -666,6 +679,7 @@
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
             blocks=blocks,
             max_blocks=max_blocks,
         )
@@ -831,10 +845,14 @@ class FlashCausalLM(Model):
         else:
             next_token_logits = out

-        next_input_ids, next_token_logprobs = batch.next_token_chooser(
+        next_input_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
             batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits
         )

+        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
+            batch.top_n_tokens, logprobs
+        )
+
         if prefill:
             if len(batch) > 1 and prefill_logprobs:
                 # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
@@ -933,6 +951,8 @@ class FlashCausalLM(Model):
             batch.next_token_chooser.seeds,
             next_token_ids,
             next_token_logprobs,
+            batch_top_token_ids,
+            batch_top_token_logprobs,
         )

         # For each member of the batch
@@ -947,7 +967,25 @@
             seed,
             next_token_id,
             next_token_logprob,
+            top_token_ids,
+            top_token_logprobs,
         ) in enumerate(iterator):
+            top_tokens = []
+            for token_id, token_logprob in zip(top_token_ids, top_token_logprobs):
+                tok_itm = token_id
+                top_tokens.append(
+                    TopToken(
+                        token_id=token_id,
+                        token_logprob=token_logprob,
+                        token_text=self.decode_token(
+                            all_input_ids=all_input_ids + [tok_itm],
+                            prefix_offset=prefix_offset,
+                            read_offset=read_offset,
+                        )[0],
+                        token_is_special=tok_itm in self.all_special_ids,
+                    )
+                )
+
             # Append next token to all tokens
             all_input_ids.append(next_token_id)
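For context on the new loop above: each candidate's text is produced by decoding the current sequence with that candidate appended, without advancing the prefix/read offsets. A toy sketch of the idea with a stubbed decoder (the stub, the offsets, and the values are illustrative assumptions, not the Model.decode_token implementation):

    def decode_token(all_input_ids, prefix_offset, read_offset):
        # Stub standing in for the model's detokenizer; returns (text, prefix_offset, read_offset).
        return f"<tok {all_input_ids[-1]}>", prefix_offset, read_offset

    all_input_ids = [1, 42, 7]         # tokens generated so far for one request
    top_token_ids = [7, 3]             # candidate ids for this request
    top_token_logprobs = [-0.1, -2.3]  # aligned log-probabilities

    top_tokens = [
        (tid, lp, decode_token(all_input_ids + [tid], 0, 2)[0])
        for tid, lp in zip(top_token_ids, top_token_logprobs)
    ]
    # Each entry pairs a candidate id with its log-probability and the text it would render as.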

@@ -230,11 +230,10 @@ class HeterogeneousNextTokenChooser:
             scores = warper(input_ids, scores)

         next_ids = self.choice(scores)
-        next_logprobs = torch.gather(
-            torch.log_softmax(scores, -1), 1, next_ids.view(-1, 1)
-        ).view(-1)
+        logprobs = torch.log_softmax(scores, -1)
+        next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)

-        return next_ids, next_logprobs
+        return next_ids, next_logprobs, logprobs

     def filter(self, indices):
         if self.watermark_processor is not None:
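The net effect of this change is that the full log-softmax is computed once and returned to the caller instead of being discarded after gathering the sampled token's log-probability. A standalone sketch of the new three-value contract, using plain tensors and argmax as a stand-in for self.choice (shapes and values assumed for illustration):

    import torch

    scores = torch.randn(2, 5)                # assumed [batch, vocab] warped logits
    next_ids = scores.argmax(dim=-1)          # stand-in for self.choice(scores)
    logprobs = torch.log_softmax(scores, -1)  # full distribution, now part of the return value
    next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)

    # Callers unpack three values and can feed `logprobs` straight into
    # batch_top_tokens without recomputing the softmax.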
@@ -342,6 +341,28 @@ class HeterogeneousSampling:
         return self


+def batch_top_tokens(top_n_tokens: torch.Tensor, logprobs: torch.Tensor):
+    """Find the top n most likely tokens for a batch of generations."""
+    top_n_tokens = torch.tensor(top_n_tokens)
+    if top_n_tokens.min() == 0:
+        return [], []
+
+    # Ensure top_n doesn't exceed vocab size
+    top_n_tokens = torch.clip(top_n_tokens, max=logprobs.size(-1))
+
+    # Take the topk using the highest requested top_n_tokens.
+    top_k = torch.topk(logprobs, k=max(top_n_tokens), dim=1, sorted=True)
+
+    # Move all values to Python lists at once to prevent multiple GPU syncs
+    top_indices = top_k.indices.tolist()
+    top_values = top_k.values.tolist()
+
+    return (
+        [idxs[:n] for idxs, n in zip(top_indices, top_n_tokens)],
+        [vals[:n] for vals, n in zip(top_values, top_n_tokens)],
+    )
+
+
 def get_top_tokens(
     requested_n: int,
     logprobs,
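A minimal usage sketch of the new helper with toy values (two requests over a five-token vocabulary; assumes the patched module is importable):

    import torch
    from text_generation_server.utils.tokens import batch_top_tokens

    top_n_tokens = [2, 3]                                    # per-request n, as collected from the requests
    logprobs = torch.log_softmax(torch.randn(2, 5), dim=-1)  # assumed [batch, vocab]

    top_ids, top_logprobs = batch_top_tokens(top_n_tokens, logprobs)

    assert len(top_ids[0]) == 2 and len(top_ids[1]) == 3     # one truncated list per request

The design runs a single topk at the largest requested n and a single .tolist() transfer, so the whole batch costs one kernel launch and one GPU sync (the early-return guard means this example assumes every request asks for at least one top token).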
@@ -354,7 +375,8 @@ def get_top_tokens(
     if not requested_n:
         return []

-    flat_scores = logprobs[-1]
+    # Dirty hack: accept either a 1-D [vocab] or a 2-D [seq, vocab] logprobs tensor
+    flat_scores = logprobs if len(logprobs.shape) == 1 else logprobs[-1]

     # Ensure top_n doesn't exceed vocab size
     top_n = min(requested_n, flat_scores.size(-1))
     # Get nth highest value, ensure it's not -inf (for example if top_n > top_k)
@@ -368,14 +390,16 @@
     for tid_tensor in top_n_indices:
         tid_item = tid_tensor[0].item()
         token_text, _, _ = decode_fn(
-            torch.cat([decoder_input_ids, tid_tensor]),
+            torch.cat([decoder_input_ids, tid_tensor])
+            if isinstance(decoder_input_ids, torch.Tensor)
+            else decoder_input_ids + [tid_item],
             prefix_offset,
             read_offset,
         )
         top_tokens.append(
             TopToken(
                 token_id=tid_item,
-                token_logprob=logprobs[-1, tid_tensor],
+                token_logprob=flat_scores[tid_tensor],
                 token_text=token_text,
                 token_is_special=tid_item in special_tokens,
             )
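Finally, a self-contained sketch of what the patched shape handling in get_top_tokens tolerates: a 1-D [vocab] tensor or a 2-D [seq, vocab] tensor, both collapsing to the scores of the last (or only) step. The helper name below is hypothetical and simply mirrors the patched expression:

    import torch

    def flatten_scores(logprobs: torch.Tensor) -> torch.Tensor:
        # Hypothetical helper mirroring the patched line in get_top_tokens.
        return logprobs if len(logprobs.shape) == 1 else logprobs[-1]

    vocab = 8
    assert flatten_scores(torch.randn(vocab)).shape == (vocab,)      # flash-style single step
    assert flatten_scores(torch.randn(3, vocab)).shape == (vocab,)   # [seq, vocab] input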